!pip install -Uq diffusers transformers fastcore
A code-first introduction to Stable Diffusion using the Diffusers library.
Importing Libraries
import logging
from pathlib import Path
import matplotlib.pyplot as plt
import torch
from diffusers import StableDiffusionPipeline
from fastcore.all import concat
from huggingface_hub import notebook_login
from PIL import Image
logging.disable(logging.WARNING)
torch.manual_seed(1)
if not (Path.home()/'.cache/huggingface'/'token').exists(): notebook_login()
def image_grid(imgs, rows, cols):
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
Stable Diffusion
Before diving into the code, let’s first take a bird’s eye view of how Stable Diffusion works. Stable Diffusion is a latent diffusion algorithm, meaning it works with images in latent space: an autoencoder compresses the images into a smaller-dimensional representation, which makes the process both faster and more memory efficient. In the forward process of diffusion, Gaussian noise is added to the (latent) images over a number of steps until they are indistinguishable from white noise. In the backward process, the images are iteratively denoised to generate new samples. The denoising is done by a U-Net model that predicts the noise, which is then subtracted from the image.
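To make the forward process concrete, here is a minimal sketch. It is illustration only: a toy tensor stands in for the image and a simple linear beta schedule is assumed, whereas Stable Diffusion applies this to VAE latents with its own trained schedule (shown later).
# Minimal sketch of the forward (noising) process; illustration only.
# Assumptions: a toy 3x64x64 tensor as the "image" and an assumed linear beta schedule.
import torch

torch.manual_seed(0)
x0 = torch.rand(3, 64, 64)                   # stand-in for a (latent) image
betas = torch.linspace(1e-4, 2e-2, 1000)     # assumed noise schedule for this sketch
alpha_bar = torch.cumprod(1 - betas, dim=0)  # cumulative signal-retention factor

t = 500                                      # an arbitrary timestep
noise = torch.randn_like(x0)
# noisy sample at step t: scaled original plus scaled Gaussian noise
xt = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * noise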
The Diffusion Pipeline
The Diffusers library by Hugging Face is arguably the easiest way to use Stable Diffusion. It provides StableDiffusionPipeline, an end-to-end inference pipeline that combines all the steps of the diffusion process so that running the model takes a single call. However, in this tutorial we will not use the pipeline directly; instead, we will rebuild it from its components to get a better understanding of how Stable Diffusion works under the hood. This also gives us more control over the output generated by the model.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", variant="fp16", torch_dtype=torch.float16).to("cuda")
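With the pipeline loaded, end-to-end generation is a single call. The snippet below is shown only for reference; we won’t rely on it, since the goal is to rebuild this behaviour ourselves:
# For reference: one-line inference with the assembled pipeline (we won't use this below).
image = pipe("Batman drinking coffee while sitting in a cafe").images[0]
image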
Looking inside the pipeline
The diffusion pipeline combines all the components of Stable Diffusion for convenient inference. Let’s rebuild the pipeline by putting its components together.
The main components of a latent diffusion algorithm are:
- An autoencoder (VAE), for compressing the image into the latent space
- A U-Net, for predicting the noise in the latent image
- A text encoder (CLIP), for creating embeddings of the text prompt
The output of the U-Net, the noise residual, is used to gradually denoise the latent image representation via a scheduler algorithm.
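Before discarding the pipeline object, we can peek at which classes implement these components; the attribute names used below are the ones exposed by StableDiffusionPipeline:
# Inspect the classes behind the pipeline's components before rebuilding them ourselves.
for name in ('vae', 'text_encoder', 'tokenizer', 'unet', 'scheduler'):
    print(name, '->', type(getattr(pipe, name)).__name__)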
del pipe  # free memory before loading the components individually
First, we need a text tokenizer and a text encoder. Stable Diffusion uses the CLIP encoder by OpenAI, so we’ll be using its weights:
from transformers import CLIPTextModel, CLIPTokenizer
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16)
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.float16).to("cuda")
Next, we’ll use the vae and unet models.
from diffusers import AutoencoderKL, UNet2DConditionModel
# Here we use a different VAE to the original release, which has been fine-tuned for more steps
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema", torch_dtype=torch.float16).to("cuda")
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet", torch_dtype=torch.float16).to("cuda")
The standard pipeline uses the PNDM scheduler, but we’ll use the K-LMS scheduler. We need to use the same noising schedule that was used during training. The schedule is defined by the timesteps at which noise is added and the amount of noise added at each step, which is given by the beta parameters. For the K-LMS scheduler, this is how the betas evolve:
beta_start, beta_end = 0.00085, 0.012
plt.plot(torch.linspace(beta_start ** 0.5, beta_end ** 0.5, 1000) ** 2)
plt.xlabel('Timestep')
plt.ylabel('β')
from diffusers import LMSDiscreteScheduler
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule='scaled_linear', num_train_timesteps=1000)
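As an aside, if you’d rather not hard-code the betas, a scheduler can also be built from the scheduler config stored in the model repository; a sketch (the variable name repo_scheduler is ours, just for illustration):
# Sketch: build a scheduler from the config shipped with the model repo
# instead of passing beta_start/beta_end by hand.
repo_scheduler = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
print(repo_scheduler.config.beta_start, repo_scheduler.config.beta_end, repo_scheduler.config.beta_schedule)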
We’ll now define the parameters we’ll use for generation.
prompt = "Batman drinking coffee while sitting in a cafe"
height = 512
width = 512
num_inference_steps = 70
guidance_scale = 7.5
batch_size = 1
Next, we tokenize the prompt. To ensure that we have the same number of tokens for each prompt, we’ll use max-length padding and truncation.
text_input = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')
text_input['input_ids']
tensor([[49406, 7223, 5778, 2453, 1519, 4919, 530, 320, 4979, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,
49407, 49407, 49407, 49407, 49407, 49407, 49407]])
tokenizer.decode(49407), tokenizer.decode(49406)
('<|endoftext|>', '<|startoftext|>')
text_input['attention_mask']
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0]])
The text encoder gives us the embeddings for the text prompt we have used.
text_embeddings = text_encoder(text_input.input_ids.to('cuda'))[0].half()
text_embeddings.shape
torch.Size([1, 77, 768])
We’ll also get the embeddings required for unconditional generation, which is achieved with an empty string; the model is free to go anywhere as long as the result is reasonable. These embeddings will be used for classifier-free guidance.
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer(['']*batch_size, padding='max_length', max_length=max_length, return_tensors='pt')
uncond_embeddings = text_encoder(uncond_input.input_ids.to('cuda'))[0].half()
uncond_embeddings.shape
torch.Size([1, 77, 768])
For classifier-free guidance, we need to do two forward passes: one with the conditioned input (text_embeddings) and another with the unconditional embeddings (uncond_embeddings). In practice, we can concatenate both into a single batch to avoid doing two separate forward passes.
text_embeddings = torch.cat([text_embeddings, uncond_embeddings])
text_embeddings.shape
torch.Size([2, 77, 768])
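As a quick toy illustration (made-up numbers, not part of the pipeline), classifier-free guidance extrapolates from the unconditional prediction toward the text-conditioned one, and guidance_scale controls how far we push; this is exactly the formula used inside the denoising loop below.
# Toy illustration: guidance_scale of 0 ignores the prompt, 1 is the plain conditional
# prediction, and larger values extrapolate past it, away from the unconditional one.
pred_uncond_toy = torch.tensor([0.0])
pred_text_toy = torch.tensor([1.0])
for g in (0.0, 1.0, 7.5):
    print(g, (pred_uncond_toy + g * (pred_text_toy - pred_uncond_toy)).item())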
To begin the denoising process, we start from pure Gaussian noise. These are our initial latents.
torch.manual_seed(42)
latents = torch.randn((batch_size, unet.config.in_channels, height//8, width//8))
latents = latents.to('cuda').half()
latents.shape
torch.Size([1, 4, 64, 64])
We initialize the scheduler with num_inference_steps. This will prepare the internal state to be used during denoising.
scheduler.set_timesteps(num_inference_steps)
We scale the latents with the initial standard deviation required by the scheduler.
latents = latents * scheduler.init_noise_sigma
We are now ready to write the denoising loop. The timesteps go from 999 to 0, following a given schedule.
scheduler.timesteps
tensor([999.0000, 984.5217, 970.0435, 955.5652, 941.0870, 926.6087, 912.1304,
897.6522, 883.1739, 868.6957, 854.2174, 839.7391, 825.2609, 810.7826,
796.3043, 781.8261, 767.3478, 752.8696, 738.3913, 723.9130, 709.4348,
694.9565, 680.4783, 666.0000, 651.5217, 637.0435, 622.5652, 608.0870,
593.6087, 579.1304, 564.6522, 550.1739, 535.6957, 521.2174, 506.7391,
492.2609, 477.7826, 463.3044, 448.8261, 434.3478, 419.8696, 405.3913,
390.9131, 376.4348, 361.9565, 347.4783, 333.0000, 318.5217, 304.0435,
289.5652, 275.0869, 260.6087, 246.1304, 231.6522, 217.1739, 202.6956,
188.2174, 173.7391, 159.2609, 144.7826, 130.3044, 115.8261, 101.3478,
86.8696, 72.3913, 57.9130, 43.4348, 28.9565, 14.4783, 0.0000])
scheduler.sigmas
tensor([14.6146, 13.3974, 12.3033, 11.3184, 10.4301, 9.6279, 8.9020, 8.2443,
7.6472, 7.1044, 6.6102, 6.1594, 5.7477, 5.3709, 5.0258, 4.7090,
4.4178, 4.1497, 3.9026, 3.6744, 3.4634, 3.2680, 3.0867, 2.9183,
2.7616, 2.6157, 2.4794, 2.3521, 2.2330, 2.1213, 2.0165, 1.9180,
1.8252, 1.7378, 1.6552, 1.5771, 1.5031, 1.4330, 1.3664, 1.3030,
1.2427, 1.1852, 1.1302, 1.0776, 1.0272, 0.9788, 0.9324, 0.8876,
0.8445, 0.8029, 0.7626, 0.7236, 0.6858, 0.6490, 0.6131, 0.5781,
0.5438, 0.5102, 0.4770, 0.4443, 0.4118, 0.3795, 0.3470, 0.3141,
0.2805, 0.2455, 0.2084, 0.1672, 0.1174, 0.0292, 0.0000])
plt.plot(scheduler.timesteps, scheduler.sigmas[:-1])
from tqdm.auto import tqdm
for i, t in enumerate(tqdm(scheduler.timesteps)):
    input = torch.cat([latents]*2)
    input = scheduler.scale_model_input(input, t)

    # predict the noise residual
    with torch.no_grad():
        pred = unet(input, t, encoder_hidden_states=text_embeddings).sample

    # perform guidance
    pred_text, pred_uncond = pred.chunk(2)
    pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

    # compute the previous 'noisy' sample
    latents = scheduler.step(pred, t, latents).prev_sample
Now, our latents contain the denoised representation of the image. We use the vae decoder to convert it back to pixel space.
with torch.no_grad(): image = vae.decode(1 / 0.18215 * latents).sample
And finally, let’s convert it back to PIL to display the image. The arithmetic acrobatics below convert the decoded image into a form that PIL can handle.
image = (image / 2 + 0.5).clamp(0, 1)
image = image[0].detach().cpu().permute(1, 2, 0).numpy()
image = (image * 255).round().astype("uint8")
Image.fromarray(image)
Voila! Here’s our Batman with his coffee.
This was the text2image pipeline. Now, let’s build the img2img pipeline. Instead of starting from pure noise, we’ll start with an existing image and add noise to it. We then skip the initial steps of the denoising process and pretend that the noised image is what the algorithm came up with at that point.
!curl --output philosophical_dog.jpg 'https://puppytoob.com/wp-content/uploads/2013/04/Dogs_In_Wind_8.jpg'
input_image = Image.open('philosophical_dog.jpg').resize((512, 512))
input_image
Let’s create a function to encode our image into the latent space using vae.
from torchvision import transforms as tfms
def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so size 1, 4, 64, 64)
    with torch.no_grad():
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to('cuda').half()*2 - 1)
    return 0.18215 * latent.latent_dist.sample()
…and a function to get the image back from the latent representation.
def latents_to_pil(latents):
    # batch of latents -> list of images
    latents = (1 / 0.18215) * latents
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
encoded = pil_to_latent(input_image)
encoded.shape
torch.Size([1, 4, 64, 64])
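As a quick sanity check (not part of the original flow), we can push the encoded latent straight back through the decoder; the round-trip result should look almost identical to input_image:
# Round-trip check: encode -> decode should approximately reproduce the input image.
latents_to_pil(encoded)[0]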
Now, let’s repeat the process we did above, except that we will start with the input image and not with pure noise.
# set parameters
prompt = ['A painting in the style of Vincent Van Gogh']
height = 512
width = 512
num_inference_steps = 70
guidance_scale = 7.5
batch_size = 1

# prep text embeddings
text_input = tokenizer(prompt, padding='max_length', max_length=tokenizer.model_max_length, truncation=True, return_tensors='pt')
text_embeddings = text_encoder(text_input.input_ids.to('cuda'))[0].half()
max_length = text_input.input_ids.shape[-1]
uncond_input = tokenizer([""]*batch_size, padding='max_length', max_length=max_length, truncation=True, return_tensors='pt')
uncond_embeddings = text_encoder(uncond_input.input_ids.to('cuda'))[0].half()
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

# prep scheduler (set num of inference steps)
scheduler.set_timesteps(num_inference_steps)

# prep latents: noise the encoded image to the level expected at start_step
start_step = 30
start_sigma = scheduler.sigmas[start_step]
noise = torch.randn_like(encoded)
latents = scheduler.add_noise(encoded, noise, timesteps=torch.tensor([scheduler.timesteps[start_step]]))
latents = latents.to('cuda').half()
# denoising loop
for i, t in enumerate(tqdm(scheduler.timesteps)):
    if i > start_step:
        input = torch.concat([latents]*2)
        input = scheduler.scale_model_input(input, t)

        # predict noise residual
        with torch.no_grad():
            pred = unet(input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        # compute prev noisy sample
        latents = scheduler.step(pred, t, latents).prev_sample

latents_to_pil(latents)[0]
And here’s our dog, ruminating about the meaninglessness of life.
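To see the transformation at a glance, the image_grid helper we defined at the start can place the original photo next to the stylised output:
# Side-by-side comparison of the input photo and the Van Gogh-style result.
image_grid([input_image, latents_to_pil(latents)[0]], rows=1, cols=2)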