Pipelines working on SDXL for noise prediction

This commit is contained in:
Jaret Burkett
2023-07-27 11:24:33 -06:00
parent 6ab8b8b0f1
commit 596e57a6a6
3 changed files with 162 additions and 39 deletions

View File

@@ -18,7 +18,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline
sys.path.append(REPOS_ROOT)
sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler
from jobs.process import BaseTrainProcess
from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors
@@ -500,13 +500,21 @@ class BaseSDTrainProcess(BaseTrainProcess):
dtype = get_torch_dtype(self.train_config.dtype)
if self.model_config.is_xl:
# do our own scheduler
scheduler = KDPM2DiscreteScheduler(
num_train_timesteps=1000,
beta_start=0.00085,
beta_end=0.0120,
beta_schedule="scaled_linear",
)
pipe = CustomStableDiffusionXLPipeline.from_single_file(
self.model_config.name_or_path,
dtype=dtype,
scheduler_type='dpm',
device=self.device_torch
device=self.device_torch,
).to(self.device_torch)
pipe.scheduler = scheduler
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
tokenizer = [pipe.tokenizer, pipe.tokenizer_2]
unet = pipe.unet
@@ -637,10 +645,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.progress_bar = tqdm(
total=self.train_config.steps,
desc=self.job.name,
leave=True
leave=True,
initial=self.step_num,
iterable=range(0, self.train_config.steps),
)
# set it to our current step in case it was updated from a load
self.progress_bar.update(self.step_num)
# self.step_num = 0
for step in range(self.step_num, self.train_config.steps):
# todo handle dataloader here maybe, not sure

View File

@@ -171,9 +171,9 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
loss_function = torch.nn.MSELoss()
with torch.no_grad():
self.sd.noise_scheduler.set_timesteps(
self.train_config.max_denoising_steps, device=self.device_torch
)
# self.sd.noise_scheduler.set_timesteps(
# self.train_config.max_denoising_steps, device=self.device_torch
# )
self.optimizer.zero_grad()
@@ -183,6 +183,12 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
).item()
absolute_total_timesteps = 1000
max_len_timestep_str = len(str(self.train_config.max_denoising_steps))
# pad with spaces
timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ")
new_description = f"{self.job.name} ts: {timestep_str}"
self.progress_bar.set_description(new_description)
# get noise
latents = self.get_latent_noise(
pixel_height=self.rescale_config.from_resolution,
@@ -190,21 +196,37 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
).to(self.device_torch, dtype=dtype)
denoised_fraction = timesteps_to / absolute_total_timesteps
self.sd.pipeline.to(self.device_torch)
torch.set_default_device(self.device_torch)
denoised_latents = self.sd.pipeline(
num_inference_steps=1000,
denoising_end=denoised_fraction,
latents=latents,
prompt_embeds=prompt.text_embeds,
negative_prompt_embeds=neutral.text_embeds,
pooled_prompt_embeds=prompt.pooled_embeds,
negative_pooled_prompt_embeds=neutral.pooled_embeds,
output_type="latent",
num_images_per_prompt=self.train_config.batch_size,
guidance_scale=3,
).images.to(self.device_torch, dtype=dtype)
# turn off progress bar
self.sd.pipeline.set_progress_bar_config(disable=True)
current_timestep = timesteps_to
pre_train = False
if not pre_train:
# partially denoise the latents
denoised_latents = self.sd.pipeline(
num_inference_steps=self.train_config.max_denoising_steps,
denoising_end=denoised_fraction,
latents=latents,
prompt_embeds=prompt.text_embeds,
negative_prompt_embeds=neutral.text_embeds,
pooled_prompt_embeds=prompt.pooled_embeds,
negative_pooled_prompt_embeds=neutral.pooled_embeds,
output_type="latent",
num_images_per_prompt=self.train_config.batch_size,
guidance_scale=3,
).images.to(self.device_torch, dtype=dtype)
current_timestep = timesteps_to
else:
denoised_latents = latents
current_timestep = 1
self.sd.noise_scheduler.set_timesteps(
1000
)
from_prediction = self.sd.pipeline.predict_noise(
latents=denoised_latents,
@@ -213,10 +235,13 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
pooled_prompt_embeds=prompt.pooled_embeds,
negative_pooled_prompt_embeds=neutral.pooled_embeds,
timestep=current_timestep,
guidance_scale=2
guidance_scale=1,
num_images_per_prompt=self.train_config.batch_size,
# predict_noise=True,
num_inference_steps=1000,
)
reduced_from_prediction = self.reduce_size_fn(from_prediction).to("cpu", dtype=torch.float32)
reduced_from_prediction = self.reduce_size_fn(from_prediction)
# get noise prediction at reduced scale
to_denoised_latents = self.reduce_size_fn(denoised_latents).to(self.device_torch, dtype=dtype)
@@ -233,7 +258,10 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
pooled_prompt_embeds=prompt.pooled_embeds,
negative_pooled_prompt_embeds=neutral.pooled_embeds,
timestep=current_timestep,
guidance_scale=2
guidance_scale=1,
num_images_per_prompt=self.train_config.batch_size,
# predict_noise=True,
num_inference_steps=1000,
)
reduced_from_prediction.requires_grad = False

View File

@@ -1,4 +1,4 @@
from typing import Union, List, Optional, Dict, Any, Tuple
from typing import Union, List, Optional, Dict, Any, Tuple, Callable
import torch
from diffusers import StableDiffusionXLPipeline
@@ -13,17 +13,32 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
denoising_end: Optional[float] = None,
guidance_scale: float = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
timestep: Optional[int] = 1,
original_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
target_size: Optional[Tuple[int, int]] = None,
# predict_noise: bool = False,
timestep: Optional[int] = None,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -35,6 +50,20 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
used in both text-encoders
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
denoising_end (`float`, *optional*):
When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
completed before it is intentionally prematurely terminated. As a result, the returned sample will
still retain a substantial amount of noise as determined by the discrete timesteps selected by the
scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
"Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -48,6 +77,14 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
@@ -69,6 +106,15 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
@@ -78,14 +124,59 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
Guidance rescale factor should fix overexposure when using zero terminal SNR.
            timestep (`int`, *optional*, defaults to `None`):
                The timestep at which to predict the noise. If not specified, the last timestep is used.
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
`original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
`crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
`crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
`crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
For most cases, `target_size` should be set to the desired height and width of the generated image. If
not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
Examples:
Returns:
torch.FloatTensor: Predicted noise
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
`tuple`. When returning a tuple, the first element is a list with the generated images.
"""
# if not predict_noise:
# # call parent
# return super().__call__(
# prompt=prompt,
# prompt_2=prompt_2,
# height=height,
# width=width,
# num_inference_steps=num_inference_steps,
# denoising_end=denoising_end,
# guidance_scale=guidance_scale,
# negative_prompt=negative_prompt,
# negative_prompt_2=negative_prompt_2,
# num_images_per_prompt=num_images_per_prompt,
# eta=eta,
# generator=generator,
# latents=latents,
# prompt_embeds=prompt_embeds,
# negative_prompt_embeds=negative_prompt_embeds,
# pooled_prompt_embeds=pooled_prompt_embeds,
# negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
# output_type=output_type,
# return_dict=return_dict,
# callback=callback,
# callback_steps=callback_steps,
# cross_attention_kwargs=cross_attention_kwargs,
# guidance_rescale=guidance_rescale,
# original_size=original_size,
# crops_coords_top_left=crops_coords_top_left,
# target_size=target_size,
# )
# 0. Default height and width to unet
height = self.default_sample_size * self.vae_scale_factor
width = self.default_sample_size * self.vae_scale_factor
@@ -106,16 +197,12 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
# do_classifier_free_guidance = guidance_scale > 1.0
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
text_encoder_lora_scale = (
cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
)
num_images_per_prompt = 1
(
prompt_embeds,
negative_prompt_embeds,
@@ -137,7 +224,7 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
)
# 4. Prepare timesteps
self.scheduler.set_timesteps(1, device=device)
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
@@ -150,16 +237,15 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
width,
prompt_embeds.dtype,
device,
None,
generator,
latents,
)
# 7. Prepare added time ids & embeddings
add_text_embeds = pooled_prompt_embeds
crops_coords_top_left: Tuple[int, int] = (0, 0)
add_time_ids = self._get_add_time_ids(
original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
).to(device)
).to(device) # TODO DOES NOT CAST ORIGINALLY
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
@@ -172,13 +258,13 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timesteps)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep)
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
noise_pred = self.unet(
latent_model_input,
timestep=timestep,
timestep,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
added_cond_kwargs=added_cond_kwargs,