From b2e2e4bf474bbd36f1ba3568d4df2894d23b1d1c Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 27 Jul 2023 12:34:48 -0600
Subject: [PATCH] Added sd1.5 and 2.1 to the diffusers pipeline flow

---
 jobs/process/BaseSDTrainProcess.py    | 173 ++++++++--------
 jobs/process/TrainSDRescaleProcess.py |  16 +-
 toolkit/pipelines.py                  | 274 ++++++++++++++++++++++++--
 3 files changed, 349 insertions(+), 114 deletions(-)

diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 85e39a23..cd7e4301 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -3,17 +3,12 @@ import time
 from collections import OrderedDict
 import os
 
-import diffusers
-from safetensors import safe_open
-
-from library import sdxl_train_util, sdxl_model_util
-from toolkit.kohya_model_util import load_vae
 from toolkit.lora_special import LoRASpecialNetwork
 from toolkit.optimizer import get_optimizer
 from toolkit.paths import REPOS_ROOT
 import sys
 
-from toolkit.pipelines import CustomStableDiffusionXLPipeline
+from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffusionPipeline
 
 sys.path.append(REPOS_ROOT)
 sys.path.append(os.path.join(REPOS_ROOT, 'leco'))
@@ -55,8 +50,13 @@ class BaseSDTrainProcess(BaseTrainProcess):
         self.model_config = ModelConfig(**self.get_conf('model', {}))
         self.save_config = SaveConfig(**self.get_conf('save', {}))
         self.sample_config = SampleConfig(**self.get_conf('sample', {}))
-        self.first_sample_config = SampleConfig(
-            **self.get_conf('first_sample', {})) if 'first_sample' in self.config else self.sample_config
+        first_sample_config = self.get_conf('first_sample', None)
+        if first_sample_config is not None:
+            self.has_first_sample_requested = True
+            self.first_sample_config = SampleConfig(**first_sample_config)
+        else:
+            self.has_first_sample_requested = False
+            self.first_sample_config = self.sample_config
         self.logging_config = LogingConfig(**self.get_conf('logging', {}))
         self.optimizer = None
         self.lr_scheduler = None
@@ -101,19 +101,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # self.sd.text_encoder.to(self.device_torch)
         # self.sd.tokenizer.to(self.device_torch)
         # TODO add clip skip
-        if self.sd.is_xl:
-            pipeline = self.sd.pipeline
-        else:
-            pipeline = StableDiffusionPipeline(
-                vae=self.sd.vae,
-                unet=self.sd.unet,
-                text_encoder=self.sd.text_encoder,
-                tokenizer=self.sd.tokenizer,
-                scheduler=self.sd.noise_scheduler,
-                safety_checker=None,
-                feature_extractor=None,
-                requires_safety_checker=False,
-            )
+        pipeline = self.sd.pipeline
         # disable progress bar
         pipeline.set_progress_bar_config(disable=True)
 
@@ -172,24 +160,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 torch.manual_seed(current_seed)
                 torch.cuda.manual_seed(current_seed)
 
-                if self.sd.is_xl:
-                    img = pipeline(
-                        prompt,
-                        height=height,
-                        width=width,
-                        num_inference_steps=sample_config.sample_steps,
-                        guidance_scale=sample_config.guidance_scale,
-                        negative_prompt=neg,
-                    ).images[0]
-                else:
-                    img = pipeline(
-                        prompt,
-                        height=height,
-                        width=width,
-                        num_inference_steps=sample_config.sample_steps,
-                        guidance_scale=sample_config.guidance_scale,
-                        negative_prompt=neg,
-                    ).images[0]
+                img = pipeline(
+                    prompt=prompt,
+                    prompt_2=prompt,
+                    negative_prompt=neg,
+                    negative_prompt_2=neg,
+                    height=height,
+                    width=width,
+                    num_inference_steps=sample_config.sample_steps,
+                    guidance_scale=sample_config.guidance_scale,
+                ).images[0]
 
                 step_num = ''
                 if step is not None:
@@ -202,9 +182,6 @@ class BaseSDTrainProcess(BaseTrainProcess):
             output_path = 
os.path.join(sample_folder, filename) img.save(output_path) - # clear pipeline and cache to reduce vram usage - if not self.sd.is_xl: - del pipeline torch.cuda.empty_cache() # restore training state @@ -230,9 +207,12 @@ class BaseSDTrainProcess(BaseTrainProcess): }) if self.model_config.is_v2: dict['ss_v2'] = True + dict['ss_base_model_version'] = 'sd_2.1' - if self.model_config.is_xl: + elif self.model_config.is_xl: dict['ss_base_model_version'] = 'sdxl_1.0' + else: + dict['ss_base_model_version'] = 'sd_1.5' dict['ss_output_name'] = self.job.name @@ -313,7 +293,6 @@ class BaseSDTrainProcess(BaseTrainProcess): ): if height is None and pixel_height is None: raise ValueError("height or pixel_height must be specified") - raise ValueError("height or pixel_height must be specified") if width is None and pixel_width is None: raise ValueError("width or pixel_width must be specified") if height is None: @@ -371,7 +350,6 @@ class BaseSDTrainProcess(BaseTrainProcess): ): pass - def predict_noise( self, latents: torch.FloatTensor, @@ -386,17 +364,7 @@ class BaseSDTrainProcess(BaseTrainProcess): if add_time_ids is None: add_time_ids = self.get_time_ids_from_latents(latents) # todo LECOs code looks like it is omitting noise_pred - # noise_pred = train_util.predict_noise_xl( - # self.sd.unet, - # self.sd.noise_scheduler, - # timestep, - # latents, - # text_embeddings.text_embeds, - # text_embeddings.pooled_embeds, - # add_time_ids, - # guidance_scale=guidance_scale, - # guidance_rescale=guidance_rescale - # ) + latent_model_input = torch.cat([latents] * 2) latent_model_input = self.sd.noise_scheduler.scale_model_input(latent_model_input, timestep) @@ -499,64 +467,66 @@ class BaseSDTrainProcess(BaseTrainProcess): dtype = get_torch_dtype(self.train_config.dtype) + # do our own scheduler + scheduler = KDPM2DiscreteScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.0120, + beta_schedule="scaled_linear", + ) if self.model_config.is_xl: - # do our own scheduler - scheduler = KDPM2DiscreteScheduler( - num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.0120, - beta_schedule="scaled_linear", - ) - pipe = CustomStableDiffusionXLPipeline.from_single_file( self.model_config.name_or_path, dtype=dtype, scheduler_type='dpm', device=self.device_torch, ).to(self.device_torch) - pipe.scheduler = scheduler + text_encoders = [pipe.text_encoder, pipe.text_encoder_2] tokenizer = [pipe.tokenizer, pipe.tokenizer_2] - unet = pipe.unet - noise_scheduler = pipe.scheduler - vae = pipe.vae.to('cpu', dtype=dtype) - vae.eval() - vae.set_use_memory_efficient_attention_xformers(True) - for text_encoder in text_encoders: text_encoder.to(self.device_torch, dtype=dtype) text_encoder.requires_grad_(False) text_encoder.eval() - text_encoder = text_encoders - tokenizer = tokenizer - flush() - - else: - tokenizer, text_encoder, unet, noise_scheduler = model_util.load_models( + pipe = CustomStableDiffusionPipeline.from_single_file( self.model_config.name_or_path, - scheduler_name=self.train_config.noise_scheduler, - v2=self.model_config.is_v2, - v_pred=self.model_config.is_v_pred, - ) - + dtype=dtype, + scheduler_type='dpm', + device=self.device_torch, + load_safety_checker=False, + ).to(self.device_torch) + pipe.register_to_config(requires_safety_checker=False) + text_encoder = pipe.text_encoder text_encoder.to(self.device_torch, dtype=dtype) + text_encoder.requires_grad_(False) text_encoder.eval() - vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype) - vae.eval() - pipe = None + 
tokenizer = pipe.tokenizer
+
+        # scheduler doesn't get set sometimes, so we set it here
+        pipe.scheduler = scheduler
+
+        unet = pipe.unet
+        noise_scheduler = pipe.scheduler
+        vae = pipe.vae.to('cpu', dtype=dtype)
+        vae.eval()
+        vae.requires_grad_(False)
         flush()
-
-        # just for now or of we want to load a custom one
-        # put on cpu for now, we only need it when sampling
-        # vae = load_vae(self.model_config.name_or_path, dtype=dtype).to('cpu', dtype=dtype)
-        # vae.eval()
-        self.sd = StableDiffusion(vae, tokenizer, text_encoder, unet, noise_scheduler, is_xl=self.model_config.is_xl,
-                                  pipeline=pipe)
+        self.sd = StableDiffusion(
+            vae,
+            tokenizer,
+            text_encoder,
+            unet,
+            noise_scheduler,
+            is_xl=self.model_config.is_xl,
+            pipeline=pipe
+        )
 
         unet.to(self.device_torch, dtype=dtype)
         if self.train_config.xformers:
+            vae.set_use_memory_efficient_attention_xformers(True)
             unet.enable_xformers_memory_efficient_attention()
         if self.train_config.gradient_checkpointing:
             unet.enable_gradient_checkpointing()
@@ -602,19 +572,26 @@ class BaseSDTrainProcess(BaseTrainProcess):
 
             self.network.multiplier = 1.0
-
         else:
             params = []
             # assume dreambooth/finetune
             if self.train_config.train_text_encoder:
-                text_encoder.requires_grad_(True)
-                text_encoder.train()
-                params += text_encoder.parameters()
+                if self.sd.is_xl:
+                    for te in text_encoder:
+                        te.requires_grad_(True)
+                        te.train()
+                        params += te.parameters()
+                else:
+                    text_encoder.requires_grad_(True)
+                    text_encoder.train()
+                    params += text_encoder.parameters()
             if self.train_config.train_unet:
                 unet.requires_grad_(True)
                 unet.train()
                 params += unet.parameters()
 
+            # TODO recover save if training network. Maybe load from beginning
+
         ### HOOK ###
         params = self.hook_add_extra_train_params(params)
 
@@ -635,12 +612,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
         ### HOOK ###
         self.hook_before_train_loop()
 
+        if self.has_first_sample_requested:
+            self.print("Generating first sample from first sample config")
+            self.sample(0, is_first=True)
+
         # sample first
         if self.train_config.skip_first_sample:
             self.print("Skipping first sample due to config setting")
         else:
             self.print("Generating baseline samples before training")
-            self.sample(0, is_first=True)
+            self.sample(0)
 
         self.progress_bar = tqdm(
             total=self.train_config.steps,
diff --git a/jobs/process/TrainSDRescaleProcess.py b/jobs/process/TrainSDRescaleProcess.py
index 86b39f32..9ca1c967 100644
--- a/jobs/process/TrainSDRescaleProcess.py
+++ b/jobs/process/TrainSDRescaleProcess.py
@@ -129,9 +129,12 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
             print(f"Saving prompt tensors to {self.rescale_config.prompt_tensors}")
             state_dict = {}
             for prompt_txt, prompt_embeds in cache.prompts.items():
-                state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
+                state_dict[f"te:{prompt_txt}"] = prompt_embeds.text_embeds.to("cpu",
+                                                                              dtype=get_torch_dtype('fp16'))
                 if prompt_embeds.pooled_embeds is not None:
-                    state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu", dtype=get_torch_dtype('fp16'))
+                    state_dict[f"pe:{prompt_txt}"] = prompt_embeds.pooled_embeds.to("cpu",
+                                                                                    dtype=get_torch_dtype(
+                                                                                        'fp16'))
             save_file(state_dict, self.rescale_config.prompt_tensors)
 
         self.print("Encoding complete.")
@@ -158,10 +161,15 @@ class TrainSDRescaleProcess(BaseSDTrainProcess):
         ]
 
         prompt = self.prompt_cache[prompt_txt].to(device=self.device_torch, dtype=dtype)
         prompt.text_embeds.to(device=self.device_torch, dtype=dtype)
-        prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype)
         neutral = 
self.prompt_cache[""].to(device=self.device_torch, dtype=dtype) neutral.text_embeds.to(device=self.device_torch, dtype=dtype) - neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype) + if hasattr(prompt, 'pooled_embeds') \ + and hasattr(neutral, 'pooled_embeds') \ + and prompt.pooled_embeds is not None \ + and neutral.pooled_embeds is not None: + prompt.pooled_embeds.to(device=self.device_torch, dtype=dtype) + neutral.pooled_embeds.to(device=self.device_torch, dtype=dtype) + if prompt is None: raise ValueError(f"Prompt {prompt_txt} is not in cache") diff --git a/toolkit/pipelines.py b/toolkit/pipelines.py index c65d6df1..f772fa1f 100644 --- a/toolkit/pipelines.py +++ b/toolkit/pipelines.py @@ -1,7 +1,8 @@ from typing import Union, List, Optional, Dict, Any, Tuple, Callable import torch -from diffusers import StableDiffusionXLPipeline +from diffusers import StableDiffusionXLPipeline, StableDiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg @@ -13,10 +14,7 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): self, prompt: Union[str, List[str]] = None, prompt_2: Optional[Union[str, List[str]]] = None, - height: Optional[int] = None, - width: Optional[int] = None, num_inference_steps: int = 50, - denoising_end: Optional[float] = None, guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, negative_prompt_2: Optional[Union[str, List[str]]] = None, @@ -28,16 +26,9 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): negative_prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guidance_rescale: float = 0.0, - original_size: Optional[Tuple[int, int]] = None, crops_coords_top_left: Tuple[int, int] = (0, 0), - target_size: Optional[Tuple[int, int]] = None, - # predict_noise: bool = False, timestep: Optional[int] = None, ): r""" @@ -226,8 +217,6 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps - # 5. 
Prepare latent variables num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( @@ -245,7 +234,7 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): add_text_embeds = pooled_prompt_embeds add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype - ).to(device) # TODO DOES NOT CAST ORIGINALLY + ).to(device) # TODO DOES NOT CAST ORIGINALLY if do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) @@ -286,3 +275,260 @@ class CustomStableDiffusionXLPipeline(StableDiffusionXLPipeline): print('Called cpu offload', gpu_id) # fuck off pass + + +class CustomStableDiffusionPipeline(StableDiffusionPipeline): + + # replace the call so it matches SDXL call so we can use the same code and also stop early + def __call__( + self, + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, + ): + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. 
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # 7.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+        else:
+            image = latents
+            has_nsfw_concept = None
+
+        if has_nsfw_concept is None:
+            do_denormalize = [True] * image.shape[0]
+        else:
+            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
+
+    # some of the inputs are to keep it compatible with sdxl
+    def predict_noise(
+        self,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 5.0,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance_rescale: float = 0.0,
+        crops_coords_top_left: Tuple[int, int] = (0, 0),
+        timestep: Optional[int] = None,
+    ):
+
+        # 0. Default height and width to unet
+        height = self.unet.config.sample_size * self.vae_scale_factor
+        width = self.unet.config.sample_size * self.vae_scale_factor
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        text_encoder_lora_scale = (
+            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
+        )
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=text_encoder_lora_scale,
+        )
+
+        # 4. 
Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + timestep, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + return noise_pred
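
Usage note: below is a minimal sketch of how the unified pipeline flow added in this patch could be exercised. It is an illustration, not part of the change itself; the checkpoint path, prompts, and output filename are placeholders, and the exact `from_single_file` keyword arguments (e.g. `load_safety_checker`) may differ between diffusers versions. The point it demonstrates is that SD 1.5/2.1 and SDXL checkpoints can now be driven through the same call surface, so the trainer's sampling code no longer needs an is_xl branch.

import torch
from toolkit.pipelines import CustomStableDiffusionPipeline, CustomStableDiffusionXLPipeline


def load_pipeline(name_or_path, is_xl=False, device="cuda", dtype=torch.float16):
    # Both custom pipelines expose the same __call__ arguments, so callers can
    # treat SD 1.5/2.1 and SDXL checkpoints uniformly after loading.
    if is_xl:
        pipe = CustomStableDiffusionXLPipeline.from_single_file(name_or_path, torch_dtype=dtype)
    else:
        pipe = CustomStableDiffusionPipeline.from_single_file(
            name_or_path, torch_dtype=dtype, load_safety_checker=False
        )
    return pipe.to(device)


# "/path/to/model.safetensors" is a placeholder checkpoint path.
pipe = load_pipeline("/path/to/model.safetensors", is_xl=False)
image = pipe(
    prompt="a photo of a cat",
    prompt_2="a photo of a cat",    # accepted for SDXL parity; unused by SD 1.x
    negative_prompt="blurry",
    negative_prompt_2="blurry",     # same: kept so the sampling call stays identical
    height=512,
    width=512,
    num_inference_steps=20,
    guidance_scale=7.0,
).images[0]
image.save("sample.png")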