From 5fc2bb5d9cbd888ffa31d9b742e7868421b4529d Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 28 Jul 2023 08:16:29 -0600 Subject: [PATCH] Information trainer --- jobs/process/BaseSDTrainProcess.py | 27 ++- jobs/process/TrainSDRescaleProcess.py | 133 ++++------- toolkit/train_pipelines.py | 316 ++++++++++++++++++++++++++ 3 files changed, 382 insertions(+), 94 deletions(-) create mode 100644 toolkit/train_pipelines.py diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index cd7e4301..71dbd7ac 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -13,7 +13,8 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 'leco')) -from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler +from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, KDPM2DiscreteScheduler, PNDMScheduler, \ + DDIMScheduler, DDPMScheduler from jobs.process import BaseTrainProcess from toolkit.metadata import get_meta_for_safetensors, load_metadata_from_safetensors @@ -38,8 +39,9 @@ VAE_SCALE_FACTOR = 8 # 2 ** (len(vae.config.block_out_channels) - 1) = 8 class BaseSDTrainProcess(BaseTrainProcess): - def __init__(self, process_id: int, job, config: OrderedDict): + def __init__(self, process_id: int, job, config: OrderedDict, custom_pipeline=None): super().__init__(process_id, job, config) + self.custom_pipeline = custom_pipeline self.step_num = 0 self.start_step = 0 self.device = self.get_conf('device', self.job.device) @@ -271,6 +273,7 @@ class BaseSDTrainProcess(BaseTrainProcess): ) self.print(f"Saved to {file_path}") + self.clean_up_saves() # Called before the model is loaded def hook_before_model_load(self): @@ -467,18 +470,24 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) + # TODO handle other schedulers + sch = KDPM2DiscreteScheduler # do our own scheduler - scheduler = KDPM2DiscreteScheduler( + scheduler = sch( num_train_timesteps=1000, beta_start=0.00085, beta_end=0.0120, beta_schedule="scaled_linear", ) if self.model_config.is_xl: - pipe = CustomStableDiffusionXLPipeline.from_single_file( + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = CustomStableDiffusionXLPipeline + pipe = pipln.from_single_file( self.model_config.name_or_path, dtype=dtype, - scheduler_type='dpm', + scheduler_type='ddpm', device=self.device_torch, ).to(self.device_torch) @@ -490,7 +499,11 @@ class BaseSDTrainProcess(BaseTrainProcess): text_encoder.eval() text_encoder = text_encoders else: - pipe = CustomStableDiffusionPipeline.from_single_file( + if self.custom_pipeline is not None: + pipln = self.custom_pipeline + else: + pipln = CustomStableDiffusionPipeline + pipe = pipln.from_single_file( self.model_config.name_or_path, dtype=dtype, scheduler_type='dpm', @@ -614,7 +627,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.has_first_sample_requested: self.print("Generating first sample from first sample config") - self.sample(0, is_first=False) + self.sample(0, is_first=True) # sample first if self.train_config.skip_first_sample: diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py index 9ca1c967..51e4f3d2 100644 --- a/jobs/process/TrainSDRescaleProcess.py +++ b/jobs/process/TrainSDRescaleProcess.py @@ -5,6 +5,7 @@ from collections import OrderedDict import os from typing import Optional +import numpy as np from safetensors.torch import load_file, save_file from tqdm import tqdm @@ -14,6 +15,7 @@ from toolkit.paths import REPOS_ROOT import sys from toolkit.stable_diffusion_model import PromptEmbeds +from toolkit.train_pipelines import TransferStableDiffusionXLPipeline sys.path.append(REPOS_ROOT) sys.path.append(os.path.join(REPOS_ROOT, 'leco')) @@ -61,7 +63,8 @@ class PromptEmbedsCache: class TrainSDRescaleProcess(BaseSDTrainProcess): def __init__(self, process_id: int, job, config: OrderedDict): - super().__init__(process_id, job, config) + # pass our custom pipeline to super so it sets it up + super().__init__(process_id, job, config, custom_pipeline=TransferStableDiffusionXLPipeline) self.step_num = 0 self.start_step = 0 self.device = self.get_conf('device', self.job.device) @@ -173,9 +176,6 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): if prompt is None: raise ValueError(f"Prompt {prompt_txt} is not in cache") - noise_scheduler = self.sd.noise_scheduler - optimizer = self.optimizer - lr_scheduler = self.lr_scheduler loss_function = torch.nn.MSELoss() with torch.no_grad(): @@ -189,13 +189,6 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): timesteps_to = torch.randint( 1, self.train_config.max_denoising_steps, (1,) ).item() - absolute_total_timesteps = 1000 - - max_len_timestep_str = len(str(self.train_config.max_denoising_steps)) - # pad with spaces - timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ") - new_description = f"{self.job.name} ts: {timestep_str}" - self.progress_bar.set_description(new_description) # get noise latents = self.get_latent_noise( @@ -203,105 +196,71 @@ class TrainSDRescaleProcess(BaseSDTrainProcess): pixel_width=self.rescale_config.from_resolution, ).to(self.device_torch, dtype=dtype) - denoised_fraction = timesteps_to / absolute_total_timesteps self.sd.pipeline.to(self.device_torch) torch.set_default_device(self.device_torch) # turn off progress bar self.sd.pipeline.set_progress_bar_config(disable=True) - pre_train = False + # get random guidance scale from 1.0 to 10.0 + guidance_scale = torch.rand(1).item() * 9.0 + 1.0 - if not pre_train: - # partially denoise the latents - denoised_latents = self.sd.pipeline( - num_inference_steps=self.train_config.max_denoising_steps, - denoising_end=denoised_fraction, - latents=latents, - prompt_embeds=prompt.text_embeds, - negative_prompt_embeds=neutral.text_embeds, - pooled_prompt_embeds=prompt.pooled_embeds, - negative_pooled_prompt_embeds=neutral.pooled_embeds, - output_type="latent", - num_images_per_prompt=self.train_config.batch_size, - guidance_scale=3, - ).images.to(self.device_torch, dtype=dtype) - current_timestep = timesteps_to + loss_arr = [] - else: - denoised_latents = latents - current_timestep = 1 - self.sd.noise_scheduler.set_timesteps( - 1000 - ) + max_len_timestep_str = len(str(self.train_config.max_denoising_steps)) + # pad with spaces + timestep_str = str(timesteps_to).rjust(max_len_timestep_str, " ") + new_description = f"{self.job.name} ts: {timestep_str}" + self.progress_bar.set_description(new_description) - from_prediction = self.sd.pipeline.predict_noise( - latents=denoised_latents, + def pre_condition_callback(target_pred, input_latents): + # handle any manipulations before feeding to our network + reduced_pred = self.reduce_size_fn(target_pred) + reduced_latents = self.reduce_size_fn(input_latents) + self.optimizer.zero_grad() + return reduced_pred, reduced_latents + + def each_step_callback(noise_target, noise_train_pred): + noise_target.requires_grad = False + loss = loss_function(noise_target, noise_train_pred) + loss_arr.append(loss.item()) + loss.backward() + self.optimizer.step() + self.lr_scheduler.step() + self.optimizer.zero_grad() + + # run the pipeline + self.sd.pipeline.transfer_diffuse( + num_inference_steps=timesteps_to, + latents=latents, prompt_embeds=prompt.text_embeds, negative_prompt_embeds=neutral.text_embeds, pooled_prompt_embeds=prompt.pooled_embeds, negative_pooled_prompt_embeds=neutral.pooled_embeds, - timestep=current_timestep, - guidance_scale=1, + output_type="latent", num_images_per_prompt=self.train_config.batch_size, - # predict_noise=True, - num_inference_steps=1000, + guidance_scale=guidance_scale, + network=self.network, + target_unet=self.sd.unet, + pre_condition_callback=pre_condition_callback, + each_step_callback=each_step_callback, ) - reduced_from_prediction = self.reduce_size_fn(from_prediction) - - # get noise prediction at reduced scale - to_denoised_latents = self.reduce_size_fn(denoised_latents).to(self.device_torch, dtype=dtype) - - # start gradient - optimizer.zero_grad() - self.network.multiplier = 1.0 - with self.network: - assert self.network.is_active is True - to_prediction = self.sd.pipeline.predict_noise( - latents=to_denoised_latents, - prompt_embeds=prompt.text_embeds, - negative_prompt_embeds=neutral.text_embeds, - pooled_prompt_embeds=prompt.pooled_embeds, - negative_pooled_prompt_embeds=neutral.pooled_embeds, - timestep=current_timestep, - guidance_scale=1, - num_images_per_prompt=self.train_config.batch_size, - # predict_noise=True, - num_inference_steps=1000, - ) - - reduced_from_prediction.requires_grad = False - from_prediction.requires_grad = False - - loss = loss_function( - reduced_from_prediction, - to_prediction, - ) - - loss_float = loss.item() - - loss = loss.to(self.device_torch) - - loss.backward() - optimizer.step() - lr_scheduler.step() - - del ( - reduced_from_prediction, - from_prediction, - to_denoised_latents, - to_prediction, - latents, - ) flush() # reset network self.network.multiplier = 1.0 + # average losses + s = 0 + for num in loss_arr: + s += num + + avg_loss = s / len(loss_arr) + loss_dict = OrderedDict( - {'loss': loss_float}, + {'loss': avg_loss}, ) return loss_dict diff --git a/toolkit/train_pipelines.py b/toolkit/train_pipelines.py new file mode 100644 index 00000000..b9cc623c --- /dev/null +++ b/toolkit/train_pipelines.py @@ -0,0 +1,316 @@ +from typing import Optional, Tuple, Callable, Dict, Any, Union, List + +import torch +from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput +from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg + +from toolkit.lora_special import LoRASpecialNetwork +from toolkit.pipelines import CustomStableDiffusionXLPipeline + + +class TransferStableDiffusionXLPipeline(CustomStableDiffusionXLPipeline): + def transfer_diffuse( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + target_unet: Optional[torch.nn.Module] = None, + pre_condition_callback = None, + each_step_callback = None, + network: Optional[LoRASpecialNetwork] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + conditioned_noise_pred, conditioned_latent_model_input = pre_condition_callback( + noise_pred.clone().detach(), + latent_model_input.clone().detach(), + ) + + # start grad + with torch.enable_grad(): + with network: + assert network.is_active + noise_train_pred = target_unet( + conditioned_latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + each_step_callback(conditioned_noise_pred, noise_train_pred) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) +