Pipelines working on SDXL for noise prediction

This commit is contained in:
Jaret Burkett
2023-07-27 11:24:33 -06:00
parent 6ab8b8b0f1
commit 596e57a6a6
3 changed files with 162 additions and 39 deletions

View File

@@ -1,4 +1,4 @@
from typing import Union, List, Optional, Dict, Any, Tuple
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import torch
from diffusers import StableDiffusionXLPipeline
@@ -13,17 +13,32 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
timestep: Optional[int] = 1,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
# predict_noise: bool = False,
timestep: Optional[int] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -35,6 +50,20 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -48,6 +77,14 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -69,6 +106,15 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -78,14 +124,59 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
timestep (`int`, *optional*, defaults to `1`):
The timestep at which to generate the image. If not specified, the last timestep is used.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
Examples:
Returns:
torch.FloatTensor: Predicted noise
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# if not predict_noise:
# # call parent
# return super().__call__(
# prompt=prompt,
# prompt_2=prompt_2,
# height=height,
# width=width,
# num_inference_steps=num_inference_steps,
# denoising_end=denoising_end,
# guidance_scale=guidance_scale,
# negative_prompt=negative_prompt,
# negative_prompt_2=negative_prompt_2,
# num_images_per_prompt=num_images_per_prompt,
# eta=eta,
# generator=generator,
# latents=latents,
# prompt_embeds=prompt_embeds,
# negative_prompt_embeds=negative_prompt_embeds,
# pooled_prompt_embeds=pooled_prompt_embeds,
# negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
# output_type=output_type,
# return_dict=return_dict,
# callback=callback,
# callback_steps=callback_steps,
# cross_attention_kwargs=cross_attention_kwargs,
# guidance_rescale=guidance_rescale,
# original_size=original_size,
# crops_coords_top_left=crops_coords_top_left,
# target_size=target_size,
# )
# 0. Default height and width to unet
height = self.default_sample_size * self.vae_scale_factor
width = self.default_sample_size * self.vae_scale_factor
@@ -106,16 +197,12 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# do_classifier_free_guidance = guidance_scale > 1.0
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
num_images_per_prompt = 1
(
prompt_embeds,
negative_prompt_embeds,
@@ -137,7 +224,7 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(1, device=device)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -150,16 +237,15 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
width,
prompt_embeds.dtype,
device,
None,
generator,
latents,
)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
crops_coords_top_left: Tuple[int, int] = (0, 0)
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
).to(device)
).to(device) # TODO DOES NOT CAST ORIGINALLY
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -172,13 +258,13 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.unet(
latent_model_input,
timestep=timestep,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,