import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import torch
import torchvision.transforms as tfms
from transformers import (
    CLIPTokenizer,
    CLIPTextModel,
)
from diffusers import (
    DDPMScheduler,
    AutoencoderKL,
    UNet2DConditionModel,
)
Implementing DiffEdit: Editing Images with Text Prompts
Overview and Setup
The goal of this post is to implement DiffEdit, an algorithm that edits images using text prompts, proposed in the paper DiffEdit: Diffusion-based semantic image editing with mask guidance. The following pair of images illustrates what the algorithm can do. The image on the left is a real image from the USDA’s SNAP-Ed page on blackberries. The image on the right was generated from the real image using DiffEdit and the indicated text prompts.
Thanks go to Hugging Face for its wonderful diffusers library used in this post and for its tutorials on how to use the library, which helped me a great deal. Thanks also go to Jonathan Whitaker for his very helpful deep dive notebook on Stable Diffusion, which I used as a source when creating this post. Finally, thanks go to the authors of the DiffEdit paper, Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord, for developing the algorithm and for the clear writing and explanations in their paper.
This post was written using the following package versions.

| Package | Version |
|---|---|
| python | 3.9.16 |
| PyTorch | 2.2.2+cu121 |
| torchvision | 0.17.2+cu121 |
| matplotlib | 3.8.4 |
| numpy | 1.26.4 |
| PIL | 10.3.0 |
| transformers | 4.39.3 |
| diffusers | 0.27.2 |
Outline of DiffEdit
The DiffEdit algorithm has three steps:

- Step 1: Generate a Mask
  - Generate a binary mask highlighting the portion of the image that we want to edit.
- Step 2: Add Noise
  - Add noise to the image using an inverse DDIM process, as described in the paper Denoising Diffusion Implicit Models.
- Step 3: Carefully Remove the Noise
  - Remove the noise from the final noised image using a DDIM process from the same paper, modified to make the edit that we want.
Step 1: Generate a Mask
First, let’s load our image.
from diffusers.utils import load_image

image_path = "usda_images/blackberries.jpg"
sample_image = load_image(image_path)
sample_image
We’ll generate our mask by adding noise to the image and then estimating the noise in two ways: once conditioned on a text prompt that accurately describes the image and once conditioned on a text prompt that reflects the edit we want to make. First, let’s load all of the models we’ll need.
- variational auto-encoder
- scheduler
- tokenizer
- text encoder
- U-Net
These components have all been pre-trained to work well together.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

checkpoint = "CompVis/stable-diffusion-v1-4"

vae = AutoencoderKL.from_pretrained(
    checkpoint, subfolder="vae", use_safetensors=True,
).to(device)

scheduler = DDPMScheduler.from_pretrained(
    checkpoint, subfolder="scheduler",
)

tokenizer = CLIPTokenizer.from_pretrained(
    checkpoint, subfolder="tokenizer",
)

text_encoder = CLIPTextModel.from_pretrained(
    checkpoint, subfolder="text_encoder", use_safetensors=True,
).to(device)

unet = UNet2DConditionModel.from_pretrained(
    checkpoint, subfolder="unet", use_safetensors=True,
).to(device)
Before we do anything, let’s set the number of timesteps we want to use for our DDIM and inverse DDIM processes. With the scheduler we’re using, the number of timesteps can be anywhere from \(1\) to \(1000\). The more timesteps we use, the better our outputs will be, but the longer it will take to generate those outputs. We’ll use \(100\) timesteps, which will generate good results without taking too long.
scheduler.set_timesteps(100)
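As a quick sanity check, the scheduler now exposes the timesteps we'll step through. The exact values depend on the scheduler's configuration, but with its 1,000 training timesteps we expect 100 roughly evenly spaced values in descending order.

print(len(scheduler.timesteps))  # 100
print(scheduler.timesteps[:5])   # the noisiest timesteps come first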
The editing process doesn’t happen on the original image. Instead, we work with a compressed version of the image called its latent representation. The latent image is much smaller than the original, so the editing process is much faster. The following utility functions convert back and forth between an image and its latent representation.
def image2latent(image, vae):
    # scale pixel values to [-1, 1], encode, and apply the Stable Diffusion latent scaling factor
    image = tfms.ToTensor()(image)[None].to(device)
    with torch.no_grad():
        latent = vae.encode(image * 2 - 1)
    return 0.18215 * latent.latent_dist.sample()


def latent2image(latent, vae):
    # undo the scaling factor, decode, and convert back to a PIL image with values in [0, 255]
    latent = latent / 0.18215
    with torch.no_grad():
        image = vae.decode(latent).sample
    image = (image + 1) / 2
    image = image.clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype(np.uint8)
    image = Image.fromarray(image[0])
    return image
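It's worth checking what the VAE gives us. The latent has four channels and spatial dimensions one-eighth of the image's (a 512 × 512 image, for example, becomes a (1, 4, 64, 64) latent); the exact size here depends on our input image, since we don't resize it.

example_latent = image2latent(sample_image, vae)
print(example_latent.shape)
latent2image(example_latent, vae)  # should closely reconstruct the original image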
We’ll add noise to the image (really, to its latent representation) using our scheduler’s add_noise() method.
def add_noise_at_strength(latent, scheduler, strength=0.5):
    # map the strength to a timestep index: higher strength means an earlier, noisier timestep
    num_timesteps = len(scheduler.timesteps)
    target_timestep = num_timesteps - int(num_timesteps * strength)

    # add Gaussian noise at that timestep's noise level
    noise = torch.randn_like(latent)
    noisy_latent = scheduler.add_noise(
        latent, noise, scheduler.timesteps[target_timestep][None]
    )
    return noisy_latent, target_timestep
The strength parameter in the function can vary from \(0\) to \(1\) and controls how much noise we add. Setting it to \(0\) adds no noise, while setting it to \(1\) replaces the entire latent with pure noise. The DiffEdit paper recommends adding noise at strength \(0.5\). Let’s see what that looks like when we convert the noised latent back to an image.
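For example, with our \(100\) timesteps and a strength of \(0.5\), the function computes the target timestep index as \(100 - \lfloor 100 \times 0.5 \rfloor = 50\), so noise is added at the level of the timestep halfway through the schedule.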
sample_latent = image2latent(sample_image, vae)
sample_noisy_latent, sample_timestep = add_noise_at_strength(
    sample_latent, scheduler, strength=0.5
)
latent2image(sample_noisy_latent, vae)
Now we’ll define all of the other functions we need to generate the mask. We need to
- embed the text prompts
- predict the noise
- predict the difference in noise, averaged over 10 attempts for stability
- generate the mask itself
def embed_prompt(prompt, tokenizer, text_encoder):
    # tokenize the input
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)

    # convert the tokenized text to embeddings
    with torch.no_grad():
        text_embedding = text_encoder(
            text_input.input_ids
        ).last_hidden_state

    # create the "unconditional embedding"
    max_length = text_input.input_ids.shape[-1]
    unconditional_input = tokenizer(
        [""],
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        unconditional_embedding = text_encoder(
            unconditional_input.input_ids
        ).last_hidden_state

    # concatenate the embeddings for classifier-free guidance
    text_embedding = torch.cat(
        [unconditional_embedding, text_embedding]
    )
    return text_embedding
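A quick shape check. For the CLIP text encoder used by this checkpoint, the result should be something like torch.Size([2, 77, 768]): the unconditional and conditional embeddings stacked along the batch dimension.

example_embedding = embed_prompt("A bowl of blackberries.", tokenizer, text_encoder)
print(example_embedding.shape)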
def predict_noise(
    latent,
    embedded_prompt,
    scheduler,
    timestep,
    unet,
    guidance_scale=7.5,
):
    # predict the noise
    t = scheduler.timesteps[timestep]
    latent_model_input = torch.cat([latent] * 2)
    latent_model_input = scheduler.scale_model_input(latent_model_input, t)
    with torch.no_grad():
        noise_pred = unet(
            latent_model_input, t, encoder_hidden_states=embedded_prompt
        )["sample"]

    # do classifier-free guidance
    noise_pred_unconditional, noise_pred_prompt = noise_pred.chunk(2)
    noise_pred = (
        noise_pred_unconditional
        + guidance_scale * (noise_pred_prompt - noise_pred_unconditional)
    )
    return noise_pred
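As a quick check, we can run this on the noisy latent from earlier, reusing sample_noisy_latent and sample_timestep from above. The prediction has the same shape as the latent itself.

example_noise_pred = predict_noise(
    sample_noisy_latent,
    embed_prompt("A bowl of blackberries.", tokenizer, text_encoder),
    scheduler,
    sample_timestep,
    unet,
)
print(example_noise_pred.shape)  # same shape as the latent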
def compute_average_noise_difference(
    latent,
    embedded_prompt_1,
    embedded_prompt_2,
    scheduler,
    timestep,
    unet,
    guidance_scale=7.5,
    num_iterations=10,
):
    result = []
    for _ in range(num_iterations):
        # predict noise for both prompts
        noise_1 = predict_noise(
            latent, embedded_prompt_1, scheduler, timestep, unet, guidance_scale
        )
        noise_2 = predict_noise(
            latent, embedded_prompt_2, scheduler, timestep, unet, guidance_scale
        )

        # compute the average absolute difference over latent channels
        noise_diff = (noise_1 - noise_2).abs().mean(dim=1)

        # remove extreme noise predictions
        upper_bound = noise_diff.mean() + noise_diff.std()
        noise_diff = noise_diff.clamp_max(upper_bound)

        result.append(noise_diff)

    # average the absolute differences for stability
    result = torch.cat(result).mean(dim=0)

    # rescale to the interval [0, 1]
    result = (result - result.min()) / (result.max() - result.min())

    return result
def generate_mask(
    image,
    vae,
    prompt_1,
    prompt_2,
    scheduler,
    tokenizer,
    text_encoder,
    unet,
    num_iterations=10,
    guidance_scale=7.5,
    strength=0.5,
):
    # convert image to YUV format to do histogram equalization
    # not in DiffEdit paper, but helps with objects in shadows
    image = np.array(image)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2YUV)

    # do histogram equalization
    clahe = cv2.createCLAHE(clipLimit=4.0, tileGridSize=(8, 8))
    image[..., 0] = clahe.apply(image[..., 0])

    # convert image back to RGB format
    image = cv2.cvtColor(image, cv2.COLOR_YUV2RGB)
    image = Image.fromarray(image)

    # add noise
    latent = image2latent(image, vae)
    latent, timestep = add_noise_at_strength(latent, scheduler, strength=strength)

    # compute the mask
    embedded_prompt_1 = embed_prompt(prompt_1, tokenizer, text_encoder)
    embedded_prompt_2 = embed_prompt(prompt_2, tokenizer, text_encoder)
    noise_diff = compute_average_noise_difference(
        latent,
        embedded_prompt_1,
        embedded_prompt_2,
        scheduler,
        timestep,
        unet,
        guidance_scale,
        num_iterations,
    )
    mask = postprocess_and_binarize(noise_diff)

    return mask
def postprocess_and_binarize(noise_diff, filter_size=15, threshold=0.3):
    # apply Gaussian blur to improve the quality of the mask
    noise_diff = noise_diff.cpu().numpy()
    noise_diff = cv2.GaussianBlur(noise_diff, (filter_size, filter_size), 0)

    # binarize and move to the model's device
    mask = torch.from_numpy(noise_diff > threshold).long()
    mask = mask.to(device)

    return mask
Now we’ll finally generate our mask. Following the DiffEdit paper, we’ll call the prompt that accurately describes the image the reference text and the prompt that reflects the edits we want to make the query.
= "A bowl of blackberries."
reference_text = "A bowl of blueberries."
query
= generate_mask(
sample_mask
sample_image,
vae,
reference_text,
query,
scheduler,
tokenizer,
text_encoder,
unet,=7.5,
guidance_scale=0.5,
strength=10,
num_iterations )
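The mask is computed at latent resolution, so its height and width are one-eighth of the image's. A quick check:

print(sample_mask.shape)     # latent-resolution mask
print(sample_mask.unique())  # expect tensor([0, 1], ...) on the model's device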
Let’s see how we did.
resized_mask = (sample_mask.cpu().numpy().astype(np.uint8)) * 255
resized_mask = Image.fromarray(resized_mask)
resized_mask = resized_mask.resize(sample_image.size)

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
axs[0].imshow(sample_image)
axs[1].imshow(sample_image)
axs[1].imshow(resized_mask, cmap="Greys_r", alpha=0.5)
axs[0].set_axis_off()
axs[1].set_axis_off()
axs[0].set_title("Original Image")
axs[1].set_title("Image with Mask Overlay")
fig.tight_layout()
plt.show()
Step 2: Add Noise
Now we’ll add noise to our image with an inverse DDIM process. We’ll also save the intermediate noised latents for use in Step 3, where they’ll help keep the edited image close to the original outside of the mask.
The inverse DDIM process uses the following equation (Equation 2 in the DiffEdit paper). I’ve changed the notation slightly to make it clear that it can be used to both add and remove noise.
\[ \mathbf x_{t + \Delta t} = \sqrt{\alpha_{t + \Delta t}} \left( \frac{\mathbf x_t - \sqrt{1 - \alpha_t} \epsilon_\theta(\mathbf x_t, t)}{\sqrt{\alpha_t}} \right) + \sqrt{1 - \alpha_{t + \Delta t}} \; \epsilon_\theta(\mathbf x_t, t)\]
Here, \(t\) is a timestep in our inverse DDIM process, which we view as the current timestep; \(t + \Delta t\) is the next timestep in the process; \(\mathbf x_t\) is the partially noised latent at the current timestep; \(\mathbf x_{t + \Delta t}\) is the partially noised latent that we want to calculate for the next timestep; \(\alpha_t\) is a coefficient that represents the amount of noise that should be in the latent at timestep \(t\); and \(\epsilon_\theta(\mathbf x_t, t)\) is the estimated noise in the latent at timestep \(t\).
The timesteps \(t\) and coefficients \(\alpha_t\) are built into our scheduler object. We’ll compute the predicted noise \(\epsilon_\theta\) using our pre-trained U-Net.
Now let’s re-write the equation to simplify it a bit. First, we can rearrange it to get
\[ \frac{\mathbf x_{t + \Delta t}}{\sqrt{\alpha_{t + \Delta t}}} = \frac{\mathbf x_t}{\sqrt{\alpha_t}} + \left( \sqrt{\frac{1 - \alpha_{t + \Delta t}}{\alpha_{t + \Delta t}}} - \sqrt{\frac{1 - \alpha_t}{\alpha_t}} \right) \epsilon_\theta(\mathbf x_t, t).\]
Then we can set \(\beta_t = 1 - \alpha_t\) to get
\[ \frac{\mathbf x_{t + \Delta t}}{\sqrt{\alpha_{t + \Delta t}}} = \frac{\mathbf x_t}{\sqrt{\alpha_t}} + \left( \sqrt{\frac{\beta_{t + \Delta t}}{\alpha_{t + \Delta t}}} - \sqrt{\frac{\beta_t}{\alpha_t}} \right) \epsilon_\theta(\mathbf x_t, t).\]
Finally, we can set \(\gamma_t = \sqrt{\beta_t / \alpha_t}\) and multiply both sides by \(\sqrt{\alpha_{t + \Delta t}}\) to get
\[ \mathbf x_{t + \Delta t} = \sqrt{\frac{\alpha_{t + \Delta t}}{\alpha_t}} \, \mathbf x_t + \sqrt{\alpha_{t + \Delta t}} \left( \gamma_{t + \Delta t} - \gamma_t \right) \epsilon_\theta(\mathbf x_t, t),\]
which is the equation we’ll use below.
In the inverse DDIM process, \(\gamma_{t + \Delta t} - \gamma_t\) will be positive, so we’ll add noise to the latent at every step. In the regular DDIM process, the timesteps are reversed and \(\gamma_{t + \Delta t} - \gamma_t\) will be negative, so we’ll remove noise from the latent at every step.
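We can sanity-check that sign claim with the scheduler's cumulative alpha coefficients (these are the \(\alpha_t\) above): \(\gamma_t = \sqrt{(1 - \alpha_t)/\alpha_t}\) should increase with \(t\).

alpha_bar = scheduler.alphas_cumprod
gamma = ((1 - alpha_bar) / alpha_bar).sqrt()
print(bool((gamma[1:] >= gamma[:-1]).all()))  # True: gamma grows with t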
def ddim_prep_timesteps(scheduler, encoding_ratio=0.5):
    # keep only the portion of the schedule we'll actually traverse
    timesteps = scheduler.timesteps
    num_timesteps = len(timesteps)
    target_timestep = num_timesteps - int(num_timesteps * encoding_ratio)
    timesteps = timesteps[target_timestep:]
    return timesteps
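For example, with our \(100\) scheduler timesteps and an encoding ratio of \(0.8\) (the value we'll use below), this keeps the last \(80\) entries of scheduler.timesteps, which are the \(80\) least noisy timesteps.

print(len(ddim_prep_timesteps(scheduler, encoding_ratio=0.8)))  # 80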
def ddim_noising(
    latent,
    scheduler,
    unet,
    tokenizer,
    text_encoder,
    encoding_ratio=0.5,
    guidance_scale=7.5,
):
    results = []

    # embed the unconditional prompt and prep the timesteps
    embedding = embed_prompt("", tokenizer, text_encoder)
    timesteps = ddim_prep_timesteps(scheduler, encoding_ratio=encoding_ratio)
    timesteps = reversed(timesteps)

    # prep the alphas and assemble a generator of consecutive pairs
    alphas = scheduler.alphas_cumprod[timesteps]
    alphas = torch.cat([alphas.new_tensor([1.0]), alphas])
    alpha_pairs = zip(alphas[:-1], alphas[1:])

    for t, (alpha_t, alpha_t_plus_delta_t) in zip(timesteps, alpha_pairs):
        # compute the unconditional noise
        latent_model_input = torch.cat([latent] * 2)
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=embedding
            ).sample
        noise_pred_unconditional, noise_pred_prompt = noise_pred.chunk(2)
        noise_pred = (
            noise_pred_unconditional
            + guidance_scale * (noise_pred_prompt - noise_pred_unconditional)
        )

        # compute the coefficients in the DDIM equation
        beta_t = 1 - alpha_t
        beta_t_plus_delta_t = 1 - alpha_t_plus_delta_t
        gamma_t = (beta_t / alpha_t).sqrt()
        gamma_t_plus_delta_t = (beta_t_plus_delta_t / alpha_t_plus_delta_t).sqrt()

        # add noise to the latent using the DDIM equation
        latent = (
            (alpha_t_plus_delta_t / alpha_t).sqrt() * latent
            + alpha_t_plus_delta_t.sqrt() * (gamma_t_plus_delta_t - gamma_t) * noise_pred
        )
        results.append(latent)

    return results
sample_noised_latents = ddim_noising(
    sample_latent,
    scheduler,
    unet,
    tokenizer,
    text_encoder,
    encoding_ratio=0.8,
    guidance_scale=7.5,
)
Let’s see what some of these partially noised latents look like.
latents_to_show = sample_noised_latents[9::10]
latents_to_show = [sample_latent] + latents_to_show

fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(14, 10))
for noised_latent, ax in zip(latents_to_show, axs.flat):
    noised_image = latent2image(noised_latent, vae)
    ax.imshow(noised_image)
    ax.set_axis_off()
fig.tight_layout()
plt.show()
Step 3: Carefully Remove the Noise
Now we’ll remove the noise from our saved noised latents while conditioning on the query, which at each step will push the portion of the image highlighted by the mask towards the edit that we want. We’ll use the same equation for the DDIM process that we used for the inverse process, just reversing the sequence of timesteps \(t\). At each step, we’ll also replace the pixels not highlighted by the mask with the pixels from the corresponding partially noised image from Step 2, reducing the chance that the algorithm makes stray edits to the image.
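In symbols, if \(M\) is the binary mask, \(\mathbf y\) is the latent we are denoising, and \(\mathbf x\) is the partially noised latent from Step 2 at the matching noise level, then after each denoising step we set
\[ \mathbf y \leftarrow M \odot \mathbf y + (1 - M) \odot \mathbf x, \]
so that outside the mask the latent always tracks the original image.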
def ddim_guided_denoising(
    noised_latents,
    mask,
    prompt,
    scheduler,
    unet,
    tokenizer,
    text_encoder,
    encoding_ratio=0.5,
    guidance_scale=7.5,
):
    # prep the list of partially noised latents, from most to least noisy
    # we'll need the original latent without noise for the last step
    noised_latents = list(reversed(noised_latents)) + [sample_latent]

    # we start with the most noised latent
    latent = noised_latents[0]

    # embed the query and prep the timesteps
    embedding = embed_prompt(prompt, tokenizer, text_encoder)
    timesteps = ddim_prep_timesteps(scheduler, encoding_ratio=encoding_ratio)

    # prep the alphas and assemble a generator of consecutive pairs
    alphas = scheduler.alphas_cumprod[timesteps]
    alphas = torch.cat([alphas, alphas.new_tensor([1.0])])
    alpha_pairs = zip(alphas[:-1], alphas[1:])

    # pair each timestep with the saved latent at that step's target noise level
    # (the clean latent for the final step)
    loop_data = zip(timesteps, noised_latents[1:], alpha_pairs)
    for t, prev_latent, (alpha_t, alpha_t_plus_delta_t) in loop_data:
        # compute the noise conditioned on the query
        latent_model_input = torch.cat([latent] * 2)
        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=embedding
            ).sample
        noise_pred_unconditional, noise_pred_prompt = noise_pred.chunk(2)
        noise_pred = (
            noise_pred_unconditional
            + guidance_scale * (noise_pred_prompt - noise_pred_unconditional)
        )

        # compute the coefficients in the DDIM equation
        beta_t = 1 - alpha_t
        beta_t_plus_delta_t = 1 - alpha_t_plus_delta_t
        gamma_t = (beta_t / alpha_t).sqrt()
        gamma_t_plus_delta_t = (beta_t_plus_delta_t / alpha_t_plus_delta_t).sqrt()

        # remove noise from the latent using the DDIM equation
        latent = (
            (alpha_t_plus_delta_t / alpha_t).sqrt() * latent
            + alpha_t_plus_delta_t.sqrt() * (gamma_t_plus_delta_t - gamma_t) * noise_pred
        )

        # replace the region outside of the mask with the corresponding noised latent
        # helps limit stray edits to the original image
        latent = mask * latent + (1 - mask) * prev_latent

    return latent
edited_sample_latent = ddim_guided_denoising(
    sample_noised_latents,
    sample_mask,
    query,
    scheduler,
    unet,
    tokenizer,
    text_encoder,
    encoding_ratio=0.8,
    guidance_scale=7.5,
)
Let’s see how we did.
edited_sample_image = latent2image(edited_sample_latent, vae)

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(15, 6))
axs[0].imshow(sample_image)
axs[1].imshow(edited_sample_image)
axs[0].set_axis_off()
axs[1].set_axis_off()
axs[0].set_title(reference_text)
axs[1].set_title(query)
fig.tight_layout()
plt.show()
We’ve recovered the figure from the beginning of the post!