import random
from typing import Union, Optional

import torch
from diffusers import T2IAdapter, ControlNetModel
from safetensors.torch import load_file
from torch import nn

from toolkit.basic import value_map
from toolkit.clip_vision_adapter import ClipVisionAdapter
from toolkit.custom_adapter import CustomAdapter
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.guidance import GuidanceType, get_guidance_loss
from toolkit.image_utils import reduce_contrast
from toolkit.ip_adapter import IPAdapter
from toolkit.network_mixins import Network
from toolkit.prompt_utils import PromptEmbeds
from toolkit.reference_adapter import ReferenceAdapter
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.timer import Timer, DummyTimer
from toolkit.train_tools import get_torch_dtype, apply_learnable_snr_gos, apply_snr_weight
from toolkit.config_modules import TrainConfig, AdapterConfig

AdapterType = Union[
    T2IAdapter, IPAdapter, ClipVisionAdapter, ReferenceAdapter, CustomAdapter, ControlNetModel
]


class UnifiedTrainingModel(nn.Module):
    def __init__(
            self,
            sd: StableDiffusion,
            network: Optional[Network] = None,
            adapter: Optional[AdapterType] = None,
            assistant_adapter: Optional[AdapterType] = None,
            train_config: TrainConfig = None,
            adapter_config: AdapterConfig = None,
            embedding: Optional[Embedding] = None,
            timer: Timer = None,
            trigger_word: Optional[str] = None,
            gpu_ids: Optional[Union[int, list]] = None,
    ):
        super(UnifiedTrainingModel, self).__init__()
        self.sd: StableDiffusion = sd
        self.network: Optional[Network] = network
        self.adapter: Optional[AdapterType] = adapter
        self.assistant_adapter: Union['T2IAdapter', 'ControlNetModel', None] = assistant_adapter
        self.train_config: TrainConfig = train_config
        self.adapter_config: AdapterConfig = adapter_config
        self.embedding: Optional[Embedding] = embedding
        self.timer: Timer = timer
        self.trigger_word: Optional[str] = trigger_word
        self.device_torch = torch.device("cuda")
        self.gpu_ids = gpu_ids
        self.primary_gpu_id = self.gpu_ids[0]  # the first in the list is primary

        # these attributes are referenced below but were never initialized in this
        # module; the owning trainer is expected to assign them before training starts
        self.embed_config = None
        self.scaler = None
        self.snr_gos = None
        self.negative_prompt_pool = None
        self.batch_negative_prompt = None
        self._clip_image_embeds_unconditional = None

        # misc config
        self.do_long_prompts = False
        self.do_prior_prediction = False
        self.do_guided_loss = False

        if self.train_config.do_prior_divergence:
            self.do_prior_prediction = True

        # register modules from sd
        self.text_encoders = nn.ModuleList(
            [self.sd.text_encoder] if not isinstance(self.sd.text_encoder, list) else self.sd.text_encoder
        )
        self.unet = self.sd.unet
        self.vae = self.sd.vae

    def is_primary_gpu(self):
        return torch.cuda.current_device() == self.primary_gpu_id

    def before_unet_predict(self):
        pass

    def after_unet_predict(self):
        pass

    def get_prior_prediction(
            self,
            noisy_latents: torch.Tensor,
            conditional_embeds: PromptEmbeds,
            match_adapter_assist: bool,
            network_weight_list: list,
            timesteps: torch.Tensor,
            pred_kwargs: dict,
            batch: 'DataLoaderBatchDTO',
            noise: torch.Tensor,
            unconditional_embeds: Optional[PromptEmbeds] = None,
            conditioned_prompts=None,
            **kwargs
    ):
        # todo for embeddings, we need to run without trigger words
        was_unet_training = self.sd.unet.training
        was_network_active = False
        if self.network is not None:
            was_network_active = self.network.is_active
            self.network.is_active = False
        can_disable_adapter = False
        was_adapter_active = False
        if self.adapter is not None and (
                isinstance(self.adapter, IPAdapter)
                or isinstance(self.adapter, ReferenceAdapter)
                or isinstance(self.adapter, CustomAdapter)
        ):
            can_disable_adapter = True
            was_adapter_active = self.adapter.is_active
            self.adapter.is_active = False
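        # The prediction below is the base model's output with the LoRA network
        # and any image adapters switched off. Callers use it as a detached
        # regularization target (masked prior loss, adapter matching, or prior
        # divergence), so nothing in this block may carry gradients.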
        # do a prediction here so we can match its output with network multiplier set to 0.0
        with torch.no_grad():
            dtype = get_torch_dtype(self.train_config.dtype)

            embeds_to_use = conditional_embeds.clone().detach()
            # handle clip vision adapter by removing triggers from prompt and replacing with the class name
            if (self.adapter is not None and isinstance(self.adapter, ClipVisionAdapter)) or self.embedding is not None:
                prompt_list = batch.get_caption_list()
                class_name = ''

                triggers = ['[trigger]', '[name]']
                remove_tokens = []

                if self.embed_config is not None:
                    triggers.append(self.embed_config.trigger)
                    for i in range(1, self.embed_config.tokens):
                        remove_tokens.append(f"{self.embed_config.trigger}_{i}")
                    if self.embed_config.trigger_class_name is not None:
                        class_name = self.embed_config.trigger_class_name

                if self.adapter is not None:
                    triggers.append(self.adapter_config.trigger)
                    for i in range(1, self.adapter_config.num_tokens):
                        remove_tokens.append(f"{self.adapter_config.trigger}_{i}")
                    if self.adapter_config.trigger_class_name is not None:
                        class_name = self.adapter_config.trigger_class_name

                for idx, prompt in enumerate(prompt_list):
                    for remove_token in remove_tokens:
                        prompt = prompt.replace(remove_token, '')
                    for trigger in triggers:
                        prompt = prompt.replace(trigger, class_name)
                    prompt_list[idx] = prompt

                embeds_to_use = self.sd.encode_prompt(
                    prompt_list,
                    long_prompts=self.do_long_prompts
                ).to(self.device_torch, dtype=dtype).detach()

            # don't use network on this
            # self.network.multiplier = 0.0
            self.sd.unet.eval()

            if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not self.sd.is_flux:
                # we need to remove the image embeds from the prompt, except for flux
                embeds_to_use: PromptEmbeds = embeds_to_use.clone().detach()
                end_pos = embeds_to_use.text_embeds.shape[1] - self.adapter_config.num_tokens
                embeds_to_use.text_embeds = embeds_to_use.text_embeds[:, :end_pos, :]
                if unconditional_embeds is not None:
                    unconditional_embeds = unconditional_embeds.clone().detach()
                    unconditional_embeds.text_embeds = unconditional_embeds.text_embeds[:, :end_pos]

            if unconditional_embeds is not None:
                unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()

            prior_pred = self.sd.predict_noise(
                latents=noisy_latents.to(self.device_torch, dtype=dtype).detach(),
                conditional_embeddings=embeds_to_use.to(self.device_torch, dtype=dtype).detach(),
                unconditional_embeddings=unconditional_embeds,
                timestep=timesteps,
                guidance_scale=self.train_config.cfg_scale,
                rescale_cfg=self.train_config.cfg_rescale,
                **pred_kwargs  # adapter residuals in here
            )
            if was_unet_training:
                self.sd.unet.train()
            prior_pred = prior_pred.detach()

            # remove the residuals as we won't use them on prediction when matching control
            if match_adapter_assist and 'down_intrablock_additional_residuals' in pred_kwargs:
                del pred_kwargs['down_intrablock_additional_residuals']
            if match_adapter_assist and 'down_block_additional_residuals' in pred_kwargs:
                del pred_kwargs['down_block_additional_residuals']
            if match_adapter_assist and 'mid_block_additional_residual' in pred_kwargs:
                del pred_kwargs['mid_block_additional_residual']

            if can_disable_adapter:
                self.adapter.is_active = was_adapter_active
            # restore network
            # self.network.multiplier = network_weight_list
            if self.network is not None:
                self.network.is_active = was_network_active
        return prior_pred
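    # process_general_training_batch calls self.get_sigmas for the 'source' /
    # 'unaugmented' loss targets, but the method was missing from this module.
    # Below is a sketch following the standard diffusers training-script helper;
    # it assumes the active noise scheduler exposes `sigmas` and `timesteps`.
    def get_sigmas(self, timesteps: torch.Tensor, n_dim: int = 4, dtype=torch.float32) -> torch.Tensor:
        sigmas = self.sd.noise_scheduler.sigmas.to(device=self.device_torch, dtype=dtype)
        schedule_timesteps = self.sd.noise_scheduler.timesteps.to(self.device_torch)
        timesteps = timesteps.to(self.device_torch)
        # look up the sigma index for each sampled timestep
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
        sigma = sigmas[step_indices].flatten()
        # pad shape to (bs, 1, 1, 1) so it broadcasts over the latents
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma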
    def predict_noise(
            self,
            noisy_latents: torch.Tensor,
            timesteps: Union[int, torch.Tensor] = 1,
            conditional_embeds: Union[PromptEmbeds, None] = None,
            unconditional_embeds: Union[PromptEmbeds, None] = None,
            **kwargs,
    ):
        dtype = get_torch_dtype(self.train_config.dtype)
        return self.sd.predict_noise(
            latents=noisy_latents.to(self.device_torch, dtype=dtype),
            conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
            unconditional_embeddings=unconditional_embeds,
            timestep=timesteps,
            guidance_scale=self.train_config.cfg_scale,
            detach_unconditional=False,
            rescale_cfg=self.train_config.cfg_rescale,
            **kwargs
        )

    def get_guided_loss(
            self,
            noisy_latents: torch.Tensor,
            conditional_embeds: PromptEmbeds,
            match_adapter_assist: bool,
            network_weight_list: list,
            timesteps: torch.Tensor,
            pred_kwargs: dict,
            batch: 'DataLoaderBatchDTO',
            noise: torch.Tensor,
            unconditional_embeds: Optional[PromptEmbeds] = None,
            **kwargs
    ):
        loss = get_guidance_loss(
            noisy_latents=noisy_latents,
            conditional_embeds=conditional_embeds,
            match_adapter_assist=match_adapter_assist,
            network_weight_list=network_weight_list,
            timesteps=timesteps,
            pred_kwargs=pred_kwargs,
            batch=batch,
            noise=noise,
            sd=self.sd,
            unconditional_embeds=unconditional_embeds,
            scaler=self.scaler,
            **kwargs
        )
        return loss

    def get_noise(self, latents, batch_size, dtype=torch.float32):
        # get noise
        noise = self.sd.get_latent_noise(
            height=latents.shape[2],
            width=latents.shape[3],
            batch_size=batch_size,
            noise_offset=self.train_config.noise_offset,
        ).to(self.device_torch, dtype=dtype)

        if self.train_config.random_noise_shift > 0.0:
            # get random noise -1 to 1
            noise_shift = torch.rand(
                (noise.shape[0], noise.shape[1], 1, 1),
                device=noise.device, dtype=noise.dtype
            ) * 2 - 1
            # multiply by shift amount
            noise_shift *= self.train_config.random_noise_shift
            # add to noise
            noise += noise_shift

            # standardize the noise
            std = noise.std(dim=(2, 3), keepdim=True)
            normalizer = 1 / (std + 1e-6)
            noise = noise * normalizer
        return noise

    def calculate_loss(
            self,
            noise_pred: torch.Tensor,
            noise: torch.Tensor,
            noisy_latents: torch.Tensor,
            timesteps: torch.Tensor,
            batch: 'DataLoaderBatchDTO',
            mask_multiplier: Union[torch.Tensor, float] = 1.0,
            prior_pred: Union[torch.Tensor, None] = None,
            **kwargs
    ):
        loss_target = self.train_config.loss_target
        is_reg = any(batch.get_is_reg_list())

        prior_mask_multiplier = None
        target_mask_multiplier = None
        dtype = get_torch_dtype(self.train_config.dtype)

        has_mask = batch.mask_tensor is not None

        with torch.no_grad():
            loss_multiplier = torch.tensor(batch.loss_multiplier_list).to(self.device_torch, dtype=torch.float32)

        if self.train_config.match_noise_norm:
            # match the norm of the noise
            noise_norm = torch.linalg.vector_norm(noise, ord=2, dim=(1, 2, 3), keepdim=True)
            noise_pred_norm = torch.linalg.vector_norm(noise_pred, ord=2, dim=(1, 2, 3), keepdim=True)
            noise_pred = noise_pred * (noise_norm / noise_pred_norm)

        if self.train_config.pred_scaler != 1.0:
            noise_pred = noise_pred * self.train_config.pred_scaler

        target = None

        if self.train_config.target_noise_multiplier != 1.0:
            noise = noise * self.train_config.target_noise_multiplier

        if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
            if self.train_config.correct_pred_norm and not is_reg:
                with torch.no_grad():
                    # this only works if doing a prior pred
                    if prior_pred is not None:
                        prior_mean = prior_pred.mean([2, 3], keepdim=True)
                        prior_std = prior_pred.std([2, 3], keepdim=True)
                        noise_mean = noise_pred.mean([2, 3], keepdim=True)
                        noise_std = noise_pred.std([2, 3], keepdim=True)

                        mean_adjust = prior_mean - noise_mean
                        std_adjust = prior_std - noise_std

                        mean_adjust = mean_adjust * self.train_config.correct_pred_norm_multiplier
                        std_adjust = std_adjust * self.train_config.correct_pred_norm_multiplier
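                        # i.e. the target noise is re-standardized toward the
                        # prior's per-channel statistics:
                        #   noise <- (noise - mean(pred)) / (std(pred) + eps)
                        #                 * (std(pred) + std_adjust + eps)
                        #                 + mean(pred) + mean_adjust
                        # with correct_pred_norm_multiplier controlling how far
                        # the stats move toward the prior's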
                        target_mean = noise_mean + mean_adjust
                        target_std = noise_std + std_adjust

                        eps = 1e-5
                        # match the noise to the prior
                        noise = (noise - noise_mean) / (noise_std + eps)
                        noise = noise * (target_std + eps) + target_mean
                        noise = noise.detach()

            if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
                assert not self.train_config.train_turbo
                with torch.no_grad():
                    # we need to make the noise prediction be a masked blending of noise and prior_pred
                    stretched_mask_multiplier = value_map(
                        mask_multiplier,
                        batch.file_items[0].dataset_config.mask_min_value,
                        1.0,
                        0.0,
                        1.0
                    )

                    prior_mask_multiplier = 1.0 - stretched_mask_multiplier
                    # target_mask_multiplier = mask_multiplier
                    # mask_multiplier = 1.0
                target = noise
                # target = (noise * mask_multiplier) + (prior_pred * prior_mask_multiplier)
                # set masked multiplier to 1.0 so we don't double apply it
                # mask_multiplier = 1.0
        elif prior_pred is not None and not self.train_config.do_prior_divergence:
            assert not self.train_config.train_turbo
            # matching adapter prediction
            target = prior_pred
        elif self.sd.prediction_type == 'v_prediction':
            # v-parameterization training
            target = self.sd.noise_scheduler.get_velocity(batch.tensor, noise, timesteps)
        elif self.sd.is_flow_matching:
            target = (noise - batch.latents).detach()
        else:
            target = noise

        if target is None:
            target = noise

        pred = noise_pred

        if self.train_config.train_turbo:
            raise ValueError("Turbo training is not supported in UnifiedTrainingModel")

        ignore_snr = False

        if loss_target == 'source' or loss_target == 'unaugmented':
            assert not self.train_config.train_turbo
            # ignore_snr = True
            if batch.sigmas is None:
                raise ValueError("Batch sigmas is None. This should not happen")

            # src https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1190
            denoised_latents = noise_pred * (-batch.sigmas) + noisy_latents
            weighing = batch.sigmas ** -2.0
            if loss_target == 'source':
                # denoise the latent and compare to the latent in the batch
                target = batch.latents
            elif loss_target == 'unaugmented':
                # we have to encode images into latents for now
                # we also denoise, as the unaugmented tensor is not a noisy differential
                with torch.no_grad():
                    unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor).to(self.device_torch, dtype=dtype)
                    unaugmented_latents = unaugmented_latents * self.train_config.latent_multiplier
                    target = unaugmented_latents.detach()

            # Get the target for loss depending on the prediction type
            if self.sd.noise_scheduler.config.prediction_type == "epsilon":
                target = target  # we are computing loss against denoised latents
            elif self.sd.noise_scheduler.config.prediction_type == "v_prediction":
                target = self.sd.noise_scheduler.get_velocity(target, noise, timesteps)
            else:
                raise ValueError(f"Unknown prediction type {self.sd.noise_scheduler.config.prediction_type}")

            # mse loss without reduction
            loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
            loss = loss_per_element
        else:
            if self.train_config.loss_type == "mae":
                loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
            else:
                loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")

        # handle linear timesteps and only adjust the weight of the timesteps
        if self.sd.is_flow_matching and self.train_config.linear_timesteps:
            # calculate the weights for the timesteps
            timestep_weight = self.sd.noise_scheduler.get_weights_for_timesteps(timesteps).to(loss.device, dtype=loss.dtype)
            timestep_weight = timestep_weight.view(-1, 1, 1, 1).detach()
            loss = loss * timestep_weight

        if self.train_config.do_prior_divergence and prior_pred is not None:
            loss = loss + (torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none") * -1.0)

        if self.train_config.train_turbo:
            mask_multiplier = mask_multiplier[:, 3:, :, :]
            # resize to the size of the loss
            mask_multiplier = torch.nn.functional.interpolate(
                mask_multiplier, size=(pred.shape[2], pred.shape[3]), mode='nearest'
            )

        # multiply by our mask
        loss = loss * mask_multiplier

        prior_loss = None
        if self.train_config.inverted_mask_prior and prior_pred is not None and prior_mask_multiplier is not None:
            assert not self.train_config.train_turbo
            if self.train_config.loss_type == "mae":
                prior_loss = torch.nn.functional.l1_loss(pred.float(), prior_pred.float(), reduction="none")
            else:
                prior_loss = torch.nn.functional.mse_loss(pred.float(), prior_pred.float(), reduction="none")

            prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
            if torch.isnan(prior_loss).any():
                print("Prior loss is nan")
                prior_loss = None
            else:
                prior_loss = prior_loss.mean([1, 2, 3])
                # loss = loss + prior_loss

        loss = loss.mean([1, 2, 3])
        # apply loss multiplier before prior loss
        loss = loss * loss_multiplier
        if prior_loss is not None:
            loss = loss + prior_loss

        if not self.train_config.train_turbo:
            if self.train_config.learnable_snr_gos:
                # add snr_gamma
                loss = apply_learnable_snr_gos(loss, timesteps, self.snr_gos)
            elif self.train_config.snr_gamma is not None and self.train_config.snr_gamma > 0.000001 and not ignore_snr:
                # add snr_gamma
                loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.snr_gamma, fixed=True)
            elif self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001 and not ignore_snr:
                # add min_snr_gamma
                loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)

        loss = loss.mean()

        # check for additional losses
        if self.adapter is not None and hasattr(self.adapter, "additional_loss") and self.adapter.additional_loss is not None:
            loss = loss + self.adapter.additional_loss.mean()
            self.adapter.additional_loss = None

        if self.train_config.target_norm_std:
            # separate out the batch and channels
            pred_std = noise_pred.std([2, 3], keepdim=True)
            norm_std_loss = torch.abs(self.train_config.target_norm_std_value - pred_std).mean()
            loss = loss + norm_std_loss

        return loss

    def preprocess_batch(self, batch: 'DataLoaderBatchDTO'):
        return batch

    def process_general_training_batch(self, batch: 'DataLoaderBatchDTO'):
        with torch.no_grad():
            with self.timer('prepare_prompt'):
                prompts = batch.get_caption_list()
                is_reg_list = batch.get_is_reg_list()

                is_any_reg = any(is_reg_list)

                do_double = self.train_config.short_and_long_captions and not is_any_reg

                if self.train_config.short_and_long_captions and do_double:
                    # don't do this with regs; no point
                    # double batch and add short captions to the end
                    prompts = prompts + batch.get_caption_short_list()
                    is_reg_list = is_reg_list + is_reg_list

                if self.sd.model_config.refiner_name_or_path is not None and self.train_config.train_unet:
                    prompts = prompts + prompts
                    is_reg_list = is_reg_list + is_reg_list

                conditioned_prompts = []

                for prompt, is_reg in zip(prompts, is_reg_list):

                    # make sure the embedding is in the prompts
                    if self.embedding is not None:
                        prompt = self.embedding.inject_embedding_to_prompt(
                            prompt,
                            expand_token=True,
                            add_if_not_present=not is_reg,
                        )

                    if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
                        prompt = self.adapter.inject_trigger_into_prompt(
                            prompt,
                            expand_token=True,
                            add_if_not_present=not is_reg,
                        )

                    # make sure the trigger is in the prompts if not a regularization run
                    if self.trigger_word is not None:
                        prompt = self.sd.inject_trigger_into_prompt(
                            prompt,
                            trigger=self.trigger_word,
                            add_if_not_present=not is_reg,
                        )

                    if not is_reg and self.train_config.prompt_saturation_chance > 0.0:
                        # do random prompt saturation by expanding the prompt to hit at least 77 tokens
                        if random.random() < self.train_config.prompt_saturation_chance:
                            est_num_tokens = len(prompt.split(' '))
                            if est_num_tokens < 77:
                                num_repeats = int(77 / est_num_tokens) + 1
                                prompt = ', '.join([prompt] * num_repeats)

                    conditioned_prompts.append(prompt)

            with self.timer('prepare_latents'):
                dtype = get_torch_dtype(self.train_config.dtype)
                imgs = None
                is_reg = any(batch.get_is_reg_list())
                if batch.tensor is not None:
                    imgs = batch.tensor
                    imgs = imgs.to(self.device_torch, dtype=dtype)
                    # don't adjust for regs
                    if self.train_config.img_multiplier is not None and not is_reg:
                        # adjust via contrast
                        imgs = reduce_contrast(imgs, self.train_config.img_multiplier)
                if batch.latents is not None:
                    latents = batch.latents.to(self.device_torch, dtype=dtype)
                    batch.latents = latents
                else:
                    # standardize the images first if configured
                    if self.train_config.standardize_images:
                        if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
                            target_mean_list = [0.0002, -0.1034, -0.1879]
                            target_std_list = [0.5436, 0.5116, 0.5033]
                        else:
                            target_mean_list = [-0.0739, -0.1597, -0.2380]
                            target_std_list = [0.5623, 0.5295, 0.5347]
                        # Mean: tensor([-0.0739, -0.1597, -0.2380])
                        # Standard Deviation: tensor([0.5623, 0.5295, 0.5347])
                        imgs_channel_mean = imgs.mean(dim=(2, 3), keepdim=True)
                        imgs_channel_std = imgs.std(dim=(2, 3), keepdim=True)
                        imgs = (imgs - imgs_channel_mean) / imgs_channel_std
                        target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
                        target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
                        # expand them to match dim
                        target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                        target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                        imgs = imgs * target_std + target_mean
                        batch.tensor = imgs
                    # show_tensors(imgs, 'imgs')
                    latents = self.sd.encode_images(imgs)
                    batch.latents = latents

                if self.train_config.standardize_latents:
                    if self.sd.is_xl or self.sd.is_vega or self.sd.is_ssd:
                        target_mean_list = [-0.1075, 0.0231, -0.0135, 0.2164]
                        target_std_list = [0.8979, 0.7505, 0.9150, 0.7451]
                    else:
                        target_mean_list = [0.2949, -0.3188, 0.0807, 0.1929]
                        target_std_list = [0.8560, 0.9629, 0.7778, 0.6719]

                    latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True)
                    latents_channel_std = latents.std(dim=(2, 3), keepdim=True)
                    latents = (latents - latents_channel_mean) / latents_channel_std
                    target_mean = torch.tensor(target_mean_list, device=self.device_torch, dtype=dtype)
                    target_std = torch.tensor(target_std_list, device=self.device_torch, dtype=dtype)
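                    # the latents are whitened per channel and then mapped onto
                    # precomputed dataset statistics, i.e.
                    #   latents <- (latents - mean) / std * target_std + target_mean
                    # (the target values above differ between the SDXL family and
                    # the SD1.x family of VAEs)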
                    # expand them to match dim
                    target_mean = target_mean.unsqueeze(0).unsqueeze(2).unsqueeze(3)
                    target_std = target_std.unsqueeze(0).unsqueeze(2).unsqueeze(3)

                    latents = latents * target_std + target_mean
                    batch.latents = latents
                    # show_latents(latents, self.sd.vae, 'latents')

                if batch.unconditional_tensor is not None and batch.unconditional_latents is None:
                    unconditional_imgs = batch.unconditional_tensor
                    unconditional_imgs = unconditional_imgs.to(self.device_torch, dtype=dtype)
                    unconditional_latents = self.sd.encode_images(unconditional_imgs)
                    batch.unconditional_latents = unconditional_latents * self.train_config.latent_multiplier

                unaugmented_latents = None
                if self.train_config.loss_target == 'differential_noise':
                    # we determine noise from the differential of the latents
                    unaugmented_latents = self.sd.encode_images(batch.unaugmented_tensor)

            batch_size = len(batch.file_items)
            min_noise_steps = self.train_config.min_denoising_steps
            max_noise_steps = self.train_config.max_denoising_steps
            if self.sd.model_config.refiner_name_or_path is not None:
                # if we are not training the unet, then we are only doing refiner and do not need to double up
                if self.train_config.train_unet:
                    max_noise_steps = round(self.train_config.max_denoising_steps * self.sd.model_config.refiner_start_at)
                    do_double = True
                else:
                    min_noise_steps = round(self.train_config.max_denoising_steps * self.sd.model_config.refiner_start_at)
                    do_double = False

            with self.timer('prepare_noise'):
                num_train_timesteps = self.train_config.num_train_timesteps
                if self.train_config.noise_scheduler in ['custom_lcm']:
                    # we store this value on our custom one
                    self.sd.noise_scheduler.set_timesteps(
                        self.sd.noise_scheduler.train_timesteps, device=self.device_torch
                    )
                elif self.train_config.noise_scheduler in ['lcm']:
                    self.sd.noise_scheduler.set_timesteps(
                        num_train_timesteps, device=self.device_torch, original_inference_steps=num_train_timesteps
                    )
                elif self.train_config.noise_scheduler == 'flowmatch':
                    self.sd.noise_scheduler.set_train_timesteps(
                        num_train_timesteps,
                        device=self.device_torch,
                        linear=self.train_config.linear_timesteps
                    )
                else:
                    self.sd.noise_scheduler.set_timesteps(
                        num_train_timesteps, device=self.device_torch
                    )

                content_or_style = self.train_config.content_or_style
                if is_reg:
                    content_or_style = self.train_config.content_or_style_reg

                # if self.train_config.timestep_sampling == 'style' or self.train_config.timestep_sampling == 'content':
                if content_or_style in ['style', 'content']:
                    # this is from diffusers training code
                    # Cubic sampling for favoring later or earlier timesteps.
                    # For more details about why cubic sampling is used for content / structure,
                    # refer to section 3.4 of https://arxiv.org/abs/2302.08453

                    # for content / structure, it is best to favor earlier timesteps
                    # for style, it is best to favor later timesteps
                    orig_timesteps = torch.rand((batch_size,), device=latents.device)

                    if content_or_style == 'content':
                        timestep_indices = orig_timesteps ** 3 * self.train_config.num_train_timesteps
                    elif content_or_style == 'style':
                        timestep_indices = (1 - orig_timesteps ** 3) * self.train_config.num_train_timesteps

                    timestep_indices = value_map(
                        timestep_indices,
                        0,
                        self.train_config.num_train_timesteps - 1,
                        min_noise_steps,
                        max_noise_steps - 1
                    )
                    timestep_indices = timestep_indices.long().clamp(
                        min_noise_steps + 1,
                        max_noise_steps - 1
                    )
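                    # worked example: with num_train_timesteps=1000 and a draw
                    # u=0.5, 'content' picks index u**3 * 1000 = 125 (an earlier
                    # timestep) while 'style' picks (1 - u**3) * 1000 = 875 (a
                    # later one); the result is then mapped and clamped into
                    # [min_noise_steps + 1, max_noise_steps - 1]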
                elif content_or_style == 'balanced':
                    if min_noise_steps == max_noise_steps:
                        timestep_indices = torch.ones((batch_size,), device=self.device_torch) * min_noise_steps
                    else:
                        # todo: some schedulers use indices, others use timesteps; not sure what to do here
                        timestep_indices = torch.randint(
                            min_noise_steps + 1,
                            max_noise_steps - 1,
                            (batch_size,),
                            device=self.device_torch
                        )
                    timestep_indices = timestep_indices.long()
                else:
                    raise ValueError(f"Unknown content_or_style {content_or_style}")

                # do flow matching
                # if self.sd.is_flow_matching:
                #     u = compute_density_for_timestep_sampling(
                #         weighting_scheme="logit_normal",  # ["sigma_sqrt", "logit_normal", "mode", "cosmap"]
                #         batch_size=batch_size,
                #         logit_mean=0.0,
                #         logit_std=1.0,
                #         mode_scale=1.29,
                #     )
                #     timestep_indices = (u * self.sd.noise_scheduler.config.num_train_timesteps).long()

                # convert the timestep_indices to timesteps
                timesteps = [self.sd.noise_scheduler.timesteps[x.item()] for x in timestep_indices]
                timesteps = torch.stack(timesteps, dim=0)

                # get noise
                noise = self.get_noise(latents, batch_size, dtype=dtype)

                # add dynamic noise offset. Dynamic noise offsets the noise to the same channelwise mean as the latents;
                # this will negate any noise offsets
                if self.train_config.dynamic_noise_offset and not is_reg:
                    latents_channel_mean = latents.mean(dim=(2, 3), keepdim=True) / 2
                    # offset by half the channel mean so that we compensate for the mean of the latents on the noise offset per channel
                    noise = noise + latents_channel_mean

                if self.train_config.loss_target == 'differential_noise':
                    differential = latents - unaugmented_latents
                    # add noise to the differential
                    # noise = noise + differential
                    noise = noise + (differential * 0.5)
                    # noise = value_map(differential, 0, torch.abs(differential).max(), 0, torch.abs(noise).max())
                    latents = unaugmented_latents

                noise_multiplier = self.train_config.noise_multiplier
                noise = noise * noise_multiplier

                latent_multiplier = self.train_config.latent_multiplier

                # handle adaptive scaling based on std
                if self.train_config.adaptive_scaling_factor:
                    std = latents.std(dim=(2, 3), keepdim=True)
                    normalizer = 1 / (std + 1e-6)
                    latent_multiplier = normalizer

                latents = latents * latent_multiplier
                batch.latents = latents

                # normalize latents to a mean of 0 and an std of 1
                # mean_zero_latents = latents - latents.mean()
                # latents = mean_zero_latents / mean_zero_latents.std()

                if batch.unconditional_latents is not None:
                    batch.unconditional_latents = batch.unconditional_latents * self.train_config.latent_multiplier

                noisy_latents = self.sd.add_noise(latents, noise, timesteps)

                # determine scaled noise
                # todo: do we need to scale this or does it always predict full intensity?
                # noise = noisy_latents - latents

                # https://github.com/huggingface/diffusers/blob/324d18fba23f6c9d7475b0ff7c777685f7128d40/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170C17-L1171C77
                if self.train_config.loss_target == 'source' or self.train_config.loss_target == 'unaugmented':
                    sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)
                    # add it to the batch
                    batch.sigmas = sigmas
                    # todo: is this for sdxl? find out where this came from originally
                    # noisy_latents = noisy_latents / ((sigmas ** 2 + 1) ** 0.5)

            def double_up_tensor(tensor: torch.Tensor):
                if tensor is None:
                    return None
                return torch.cat([tensor, tensor], dim=0)

            if do_double:
                if self.sd.model_config.refiner_name_or_path:
                    # apply refiner double up
                    refiner_timesteps = torch.randint(
                        max_noise_steps,
                        self.train_config.max_denoising_steps,
                        (batch_size,),
                        device=self.device_torch
                    )
                    refiner_timesteps = refiner_timesteps.long()
                    # add our new timesteps onto the end
                    timesteps = torch.cat([timesteps, refiner_timesteps], dim=0)

                    refiner_noisy_latents = self.sd.noise_scheduler.add_noise(latents, noise, refiner_timesteps)
                    noisy_latents = torch.cat([noisy_latents, refiner_noisy_latents], dim=0)
                else:
                    # just double it
                    noisy_latents = double_up_tensor(noisy_latents)
                    timesteps = double_up_tensor(timesteps)

                noise = double_up_tensor(noise)
                # prompts are already updated above
                imgs = double_up_tensor(imgs)
                batch.mask_tensor = double_up_tensor(batch.mask_tensor)
                batch.control_tensor = double_up_tensor(batch.control_tensor)

            noisy_latent_multiplier = self.train_config.noisy_latent_multiplier

            if noisy_latent_multiplier != 1.0:
                noisy_latents = noisy_latents * noisy_latent_multiplier

            # remove grads for these
            noisy_latents.requires_grad = False
            noisy_latents = noisy_latents.detach()
            noise.requires_grad = False
            noise = noise.detach()

        return noisy_latents, noise, timesteps, conditioned_prompts, imgs

    def forward(self, batch: DataLoaderBatchDTO):
        if not self.is_primary_gpu():
            # replace the timer with a dummy one on non-primary GPUs
            self.timer = DummyTimer()
        self.device_torch = torch.device('cuda', torch.cuda.current_device())
        self.timer.start('preprocess_batch')
        batch = self.preprocess_batch(batch)
        dtype = get_torch_dtype(self.train_config.dtype)

        # sanity check
        if self.sd.vae.dtype != self.sd.vae_torch_dtype:
            self.sd.vae = self.sd.vae.to(self.sd.vae_torch_dtype)
        if isinstance(self.sd.text_encoder, list):
            for encoder in self.sd.text_encoder:
                if encoder.dtype != self.sd.te_torch_dtype:
                    encoder.to(self.sd.te_torch_dtype)
        else:
            if self.sd.text_encoder.dtype != self.sd.te_torch_dtype:
                self.sd.text_encoder.to(self.sd.te_torch_dtype)

        noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)

        if self.train_config.do_cfg or self.train_config.do_random_cfg:
            # pick random negative prompts
            if self.negative_prompt_pool is not None:
                negative_prompts = []
                for i in range(noisy_latents.shape[0]):
                    num_neg = random.randint(1, self.train_config.max_negative_prompts)
                    this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
                    this_neg_prompt = ', '.join(this_neg_prompts)
                    negative_prompts.append(this_neg_prompt)
                self.batch_negative_prompt = negative_prompts
            else:
                self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]

        if self.adapter and isinstance(self.adapter, CustomAdapter):
            # condition the prompt
            # todo: handle more than one adapter image
            self.adapter.num_control_images = 1
            conditioned_prompts = self.adapter.condition_prompt(conditioned_prompts)

        network_weight_list = batch.get_network_weight_list()
        if self.train_config.single_item_batching:
            network_weight_list = network_weight_list + network_weight_list

        has_adapter_img = batch.control_tensor is not None
        has_clip_image = batch.clip_image_tensor is not None
        has_clip_image_embeds = batch.clip_image_embeds is not None
        # force it to be true if doing regs, as we handle those differently
        if any([batch.file_items[idx].is_reg for idx in range(len(batch.file_items))]):
            has_clip_image = True
        if self._clip_image_embeds_unconditional is not None:
            has_clip_image_embeds = True  # we are caching embeds, handle that differently
            has_clip_image = False

        if self.adapter is not None and isinstance(self.adapter, IPAdapter) and not has_clip_image and has_adapter_img:
            raise ValueError(
                "IPAdapter control image is now 'clip_image_path' instead of 'control_path'. Please update your dataset config"
            )

        match_adapter_assist = False

        # check if we are matching the adapter assistant
        if self.assistant_adapter:
            if self.train_config.match_adapter_chance == 1.0:
                match_adapter_assist = True
            elif self.train_config.match_adapter_chance > 0.0:
                match_adapter_assist = torch.rand(
                    (1,), device=self.device_torch, dtype=dtype
                ) < self.train_config.match_adapter_chance

        self.timer.stop('preprocess_batch')

        is_reg = False
        with torch.no_grad():
            loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
            for idx, file_item in enumerate(batch.file_items):
                if file_item.is_reg:
                    loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
                    is_reg = True

        adapter_images = None
        sigmas = None
        if has_adapter_img and (self.adapter or self.assistant_adapter):
            with self.timer('get_adapter_images'):
                # todo: move this to the data loader
                if batch.control_tensor is not None:
                    adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
                    # match in channels
                    if self.assistant_adapter is not None:
                        in_channels = self.assistant_adapter.config.in_channels
                        if adapter_images.shape[1] != in_channels:
                            # we need to match the channels
                            adapter_images = adapter_images[:, :in_channels, :, :]
                else:
                    raise NotImplementedError("Adapter images must now be loaded with the dataloader")

        clip_images = None
        if has_clip_image:
            with self.timer('get_clip_images'):
                # todo: move this to the data loader
                if batch.clip_image_tensor is not None:
                    clip_images = batch.clip_image_tensor.to(self.device_torch, dtype=dtype).detach()

        mask_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
        if batch.mask_tensor is not None:
            with self.timer('get_mask_multiplier'):
                # upsampling is not supported for bfloat16
                mask_multiplier = batch.mask_tensor.to(self.device_torch, dtype=torch.float16).detach()
                # scale down to the size of the latents; mask_multiplier shape (bs, 1, width, height), noisy_latents shape (bs, channels, width, height)
                mask_multiplier = torch.nn.functional.interpolate(
                    mask_multiplier, size=(noisy_latents.shape[2], noisy_latents.shape[3])
                )
                # expand to match latents
                mask_multiplier = mask_multiplier.expand(-1, noisy_latents.shape[1], -1, -1)
                mask_multiplier = mask_multiplier.to(self.device_torch, dtype=dtype).detach()

        def get_adapter_multiplier():
            if self.adapter and isinstance(self.adapter, T2IAdapter):
                # training a t2i adapter, not using it as an assistant
                return 1.0
            elif match_adapter_assist:
                # training a texture; we want it high
                adapter_strength_min = 0.9
                adapter_strength_max = 1.0
            else:
                # training with assistance, we want it low
                # adapter_strength_min = 0.4
                # adapter_strength_max = 0.7
                adapter_strength_min = 0.5
                adapter_strength_max = 1.1

            adapter_conditioning_scale = torch.rand(
                (1,), device=self.device_torch, dtype=dtype
            )

            adapter_conditioning_scale = value_map(
                adapter_conditioning_scale,
                0.0,
                1.0,
                adapter_strength_min,
                adapter_strength_max
            )
            return adapter_conditioning_scale

        # flush()
        with self.timer('grad_setup'):
            # text encoding
            grad_on_text_encoder = False
            if self.train_config.train_text_encoder:
                grad_on_text_encoder = True

            if self.embedding is not None:
                grad_on_text_encoder = True

            if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
                grad_on_text_encoder = True

            if self.adapter_config and self.adapter_config.type == 'te_augmenter':
                grad_on_text_encoder = True

            # have a blank network so we can wrap it in a context and set multipliers without checking every time
            if self.network is not None:
                network = self.network
            else:
                network = BlankNetwork()

            # set the weights
            network.multiplier = network_weight_list

        # activate network if it exists
        prompts_1 = conditioned_prompts
        prompts_2 = None
        if self.train_config.short_and_long_captions_encoder_split and self.sd.is_xl:
            prompts_1 = batch.get_caption_short_list()
            prompts_2 = conditioned_prompts

        # make the batch splits
        if self.train_config.single_item_batching:
            if self.sd.model_config.refiner_name_or_path is not None:
                raise ValueError("Single item batching is not supported when training the refiner")
            batch_size = noisy_latents.shape[0]
            # chunk/split everything
            noisy_latents_list = torch.chunk(noisy_latents, batch_size, dim=0)
            noise_list = torch.chunk(noise, batch_size, dim=0)
            timesteps_list = torch.chunk(timesteps, batch_size, dim=0)
            conditioned_prompts_list = [[prompt] for prompt in prompts_1]

            if imgs is not None:
                imgs_list = torch.chunk(imgs, batch_size, dim=0)
            else:
                imgs_list = [None for _ in range(batch_size)]

            if adapter_images is not None:
                adapter_images_list = torch.chunk(adapter_images, batch_size, dim=0)
            else:
                adapter_images_list = [None for _ in range(batch_size)]

            if clip_images is not None:
                clip_images_list = torch.chunk(clip_images, batch_size, dim=0)
            else:
                clip_images_list = [None for _ in range(batch_size)]
            mask_multiplier_list = torch.chunk(mask_multiplier, batch_size, dim=0)
            if prompts_2 is None:
                prompt_2_list = [None for _ in range(batch_size)]
            else:
                prompt_2_list = [[prompt] for prompt in prompts_2]
        else:
            noisy_latents_list = [noisy_latents]
            noise_list = [noise]
            timesteps_list = [timesteps]
            conditioned_prompts_list = [prompts_1]
            imgs_list = [imgs]
            adapter_images_list = [adapter_images]
            clip_images_list = [clip_images]
            mask_multiplier_list = [mask_multiplier]
            if prompts_2 is None:
                prompt_2_list = [None]
            else:
                prompt_2_list = [prompts_2]

        for noisy_latents, noise, timesteps, conditioned_prompts, imgs, adapter_images, clip_images, mask_multiplier, prompt_2 in zip(
                noisy_latents_list,
                noise_list,
                timesteps_list,
                conditioned_prompts_list,
                imgs_list,
                adapter_images_list,
                clip_images_list,
                mask_multiplier_list,
                prompt_2_list
        ):
            with network:
                # encode the clip adapter here so embeds are active for the tokenizer
                if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
                    with self.timer('encode_clip_vision_embeds'):
                        if has_clip_image:
                            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                                clip_images.detach().to(self.device_torch, dtype=dtype),
                                is_training=True,
                                has_been_preprocessed=True
                            )
                        else:
                            # just do a blank one
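                            # drop=True is expected to zero out the (blank) image
                            # embedding, mirroring the reg-image path in the
                            # IPAdapter branch below ("drop will zero it out"),
                            # so the adapter sees a true unconditional input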
                            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                                torch.zeros(
                                    (noisy_latents.shape[0], 3, 512, 512),
                                    device=self.device_torch, dtype=dtype
                                ),
                                is_training=True,
                                has_been_preprocessed=True,
                                drop=True
                            )
                    # it will be injected into the tokenizer when called
                    self.adapter(conditional_clip_embeds)

                # run the custom adapter's pre-text-encoder step
                if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
                    quad_count = random.randint(1, 4)
                    self.adapter.train()
                    self.adapter.trigger_pre_te(
                        tensors_0_1=clip_images if not is_reg else None,  # on regs we send None to get random noise
                        is_training=True,
                        has_been_preprocessed=True,
                        quad_count=quad_count,
                        batch_size=noisy_latents.shape[0]
                    )

                with self.timer('encode_prompt'):
                    unconditional_embeds = None
                    if grad_on_text_encoder:
                        with torch.set_grad_enabled(True):
                            if isinstance(self.adapter, CustomAdapter):
                                self.adapter.is_unconditional_run = False
                            conditional_embeds = self.sd.encode_prompt(
                                conditioned_prompts, prompt_2,
                                dropout_prob=self.train_config.prompt_dropout_prob,
                                long_prompts=self.do_long_prompts
                            ).to(self.device_torch, dtype=dtype)

                            if self.train_config.do_cfg:
                                if isinstance(self.adapter, CustomAdapter):
                                    self.adapter.is_unconditional_run = True
                                # todo: only do one and repeat it
                                unconditional_embeds = self.sd.encode_prompt(
                                    self.batch_negative_prompt, self.batch_negative_prompt,
                                    dropout_prob=self.train_config.prompt_dropout_prob,
                                    long_prompts=self.do_long_prompts
                                ).to(self.device_torch, dtype=dtype)
                                if isinstance(self.adapter, CustomAdapter):
                                    self.adapter.is_unconditional_run = False
                    else:
                        with torch.set_grad_enabled(False):
                            # make sure it is in eval mode
                            if isinstance(self.sd.text_encoder, list):
                                for te in self.sd.text_encoder:
                                    te.eval()
                            else:
                                self.sd.text_encoder.eval()
                            if isinstance(self.adapter, CustomAdapter):
                                self.adapter.is_unconditional_run = False
                            conditional_embeds = self.sd.encode_prompt(
                                conditioned_prompts, prompt_2,
                                dropout_prob=self.train_config.prompt_dropout_prob,
                                long_prompts=self.do_long_prompts
                            ).to(self.device_torch, dtype=dtype)
                            if self.train_config.do_cfg:
                                if isinstance(self.adapter, CustomAdapter):
                                    self.adapter.is_unconditional_run = True
                                unconditional_embeds = self.sd.encode_prompt(
                                    self.batch_negative_prompt,
                                    dropout_prob=self.train_config.prompt_dropout_prob,
                                    long_prompts=self.do_long_prompts
                                ).to(self.device_torch, dtype=dtype)
                                if isinstance(self.adapter, CustomAdapter):
                                    self.adapter.is_unconditional_run = False

                        # detach the embeddings
                        conditional_embeds = conditional_embeds.detach()
                        if self.train_config.do_cfg:
                            unconditional_embeds = unconditional_embeds.detach()

                # flush()
                pred_kwargs = {}

                if has_adapter_img:
                    if (self.adapter and isinstance(self.adapter, T2IAdapter)) or (
                            self.assistant_adapter and isinstance(self.assistant_adapter, T2IAdapter)):
                        with torch.set_grad_enabled(self.adapter is not None):
                            adapter = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
                            adapter_multiplier = get_adapter_multiplier()
                            with self.timer('encode_adapter'):
                                down_block_additional_residuals = adapter(adapter_images)
                                if self.assistant_adapter:
                                    # not training; detach
                                    down_block_additional_residuals = [
                                        sample.to(dtype=dtype).detach() * adapter_multiplier
                                        for sample in down_block_additional_residuals
                                    ]
                                else:
                                    down_block_additional_residuals = [
                                        sample.to(dtype=dtype) * adapter_multiplier
                                        for sample in down_block_additional_residuals
                                    ]

                            pred_kwargs['down_intrablock_additional_residuals'] = down_block_additional_residuals

                if self.adapter and isinstance(self.adapter, IPAdapter):
                    with self.timer('encode_adapter_embeds'):
                        # number of images to do if doing a quad image
                        quad_count = random.randint(1, 4)
                        image_size = self.adapter.input_size
                        if has_clip_image_embeds:
                            # todo: handle reg images better than this
                            if is_reg:
                                # get unconditional image embeds from cache
                                embeds = [
                                    load_file(random.choice(batch.clip_image_embeds_unconditional))
                                    for i in range(noisy_latents.shape[0])
                                ]
                                conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                                    embeds,
                                    quad_count=quad_count
                                )

                                if self.train_config.do_cfg:
                                    embeds = [
                                        load_file(random.choice(batch.clip_image_embeds_unconditional))
                                        for i in range(noisy_latents.shape[0])
                                    ]
                                    unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                                        embeds,
                                        quad_count=quad_count
                                    )
                            else:
                                conditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                                    batch.clip_image_embeds,
                                    quad_count=quad_count
                                )
                                if self.train_config.do_cfg:
                                    unconditional_clip_embeds = self.adapter.parse_clip_image_embeds_from_cache(
                                        batch.clip_image_embeds_unconditional,
                                        quad_count=quad_count
                                    )
                        elif is_reg:
                            # we will zero it out in the img embedder
                            clip_images = torch.zeros(
                                (noisy_latents.shape[0], 3, image_size, image_size),
                                device=self.device_torch, dtype=dtype
                            ).detach()
                            # drop will zero it out
                            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                                clip_images,
                                drop=True,
                                is_training=True,
                                has_been_preprocessed=False,
                                quad_count=quad_count
                            )
                            if self.train_config.do_cfg:
                                unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                                    torch.zeros(
                                        (noisy_latents.shape[0], 3, image_size, image_size),
                                        device=self.device_torch, dtype=dtype
                                    ).detach(),
                                    is_training=True,
                                    drop=True,
                                    has_been_preprocessed=False,
                                    quad_count=quad_count
                                )
                        elif has_clip_image:
                            conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                                clip_images.detach().to(self.device_torch, dtype=dtype),
                                is_training=True,
                                has_been_preprocessed=True,
                                quad_count=quad_count,
                                # do cfg on clip embeds to normalize the embeddings for when doing cfg
                                # cfg_embed_strength=3.0 if not self.train_config.do_cfg else None
                            )
                            if self.train_config.do_cfg:
                                unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                                    clip_images.detach().to(self.device_torch, dtype=dtype),
                                    is_training=True,
                                    drop=True,
                                    has_been_preprocessed=True,
                                    quad_count=quad_count
                                )
                        else:
                            print("No Clip Image")
                            print([file_item.path for file_item in batch.file_items])
                            raise ValueError("Could not find clip image")

                    if not self.adapter_config.train_image_encoder:
                        # we are not training the image encoder, so we need to detach the embeds
                        conditional_clip_embeds = conditional_clip_embeds.detach()
                        if self.train_config.do_cfg:
                            unconditional_clip_embeds = unconditional_clip_embeds.detach()

                    with self.timer('encode_adapter'):
                        self.adapter.train()
                        conditional_embeds = self.adapter(
                            conditional_embeds.detach(),
                            conditional_clip_embeds,
                            is_unconditional=False
                        )
                        if self.train_config.do_cfg:
                            unconditional_embeds = self.adapter(
                                unconditional_embeds.detach(),
                                unconditional_clip_embeds,
                                is_unconditional=True
                            )
                        else:
                            # wipe out unconditional
                            self.adapter.last_unconditional = None

                if self.adapter and isinstance(self.adapter, ReferenceAdapter):
                    # pass in our scheduler (the flattened source assigned self.lr_scheduler
                    # here, which is not defined on this module; the reference adapter
                    # consumes the noise scheduler)
                    self.adapter.noise_scheduler = self.sd.noise_scheduler
                    if has_clip_image or has_adapter_img:
                        img_to_use = clip_images if has_clip_image else adapter_images
                        # currently 0-1, needs to be -1 to 1
                        reference_images = ((img_to_use - 0.5) * 2).detach().to(self.device_torch, dtype=dtype)
                        self.adapter.set_reference_images(reference_images)
                    elif is_reg:
                        self.adapter.set_blank_reference_images(noisy_latents.shape[0])
                    else:
                        self.adapter.set_reference_images(None)

                prior_pred = None

                do_reg_prior = False
                # if is_reg and (self.network is not None or self.adapter is not None):
                #     # we are doing a reg image and we have a network or adapter
                #     do_reg_prior = True

                do_inverted_masked_prior = False
                if self.train_config.inverted_mask_prior and batch.mask_tensor is not None:
                    do_inverted_masked_prior = True

                do_correct_pred_norm_prior = self.train_config.correct_pred_norm

                do_guidance_prior = False

                if batch.unconditional_latents is not None:
                    # for 'tnt' (this-not-that) guidance we need a prior pred to normalize
                    guidance_type: GuidanceType = batch.file_items[0].dataset_config.guidance_type
                    if guidance_type == 'tnt':
                        do_guidance_prior = True

                if ((has_adapter_img and self.assistant_adapter and match_adapter_assist)
                        or self.do_prior_prediction
                        or do_guidance_prior
                        or do_reg_prior
                        or do_inverted_masked_prior
                        or self.train_config.correct_pred_norm):
                    with self.timer('prior predict'):
                        prior_pred = self.get_prior_prediction(
                            noisy_latents=noisy_latents,
                            conditional_embeds=conditional_embeds,
                            match_adapter_assist=match_adapter_assist,
                            network_weight_list=network_weight_list,
                            timesteps=timesteps,
                            pred_kwargs=pred_kwargs,
                            noise=noise,
                            batch=batch,
                            unconditional_embeds=unconditional_embeds,
                            conditioned_prompts=conditioned_prompts
                        )
                        if prior_pred is not None:
                            prior_pred = prior_pred.detach()

                # do the custom adapter after the prior prediction
                if self.adapter and isinstance(self.adapter, CustomAdapter) and has_clip_image:
                    quad_count = random.randint(1, 4)
                    self.adapter.train()
                    conditional_embeds = self.adapter.condition_encoded_embeds(
                        tensors_0_1=clip_images,
                        prompt_embeds=conditional_embeds,
                        is_training=True,
                        has_been_preprocessed=True,
                        quad_count=quad_count
                    )
                    if self.train_config.do_cfg and unconditional_embeds is not None:
                        unconditional_embeds = self.adapter.condition_encoded_embeds(
                            tensors_0_1=clip_images,
                            prompt_embeds=unconditional_embeds,
                            is_training=True,
                            has_been_preprocessed=True,
                            is_unconditional=True,
                            quad_count=quad_count
                        )

                if self.adapter and isinstance(self.adapter, CustomAdapter) and batch.extra_values is not None:
                    self.adapter.add_extra_values(batch.extra_values.detach())

                    if self.train_config.do_cfg:
                        self.adapter.add_extra_values(torch.zeros_like(batch.extra_values.detach()), is_unconditional=True)

                if has_adapter_img:
                    if (self.adapter and isinstance(self.adapter, ControlNetModel)) or (
                            self.assistant_adapter and isinstance(self.assistant_adapter, ControlNetModel)):
                        if self.train_config.do_cfg:
                            raise ValueError("ControlNetModel is not supported with CFG")
                        with torch.set_grad_enabled(self.adapter is not None):
                            adapter: ControlNetModel = self.assistant_adapter if self.assistant_adapter is not None else self.adapter
                            adapter_multiplier = get_adapter_multiplier()
                            with self.timer('encode_adapter'):
                                # add_text_embeds is pooled_prompt_embeds for sdxl
                                added_cond_kwargs = {}
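                                # SDXL ControlNets take the pooled text embedding and
                                # micro-conditioning time_ids as extra inputs; non-XL
                                # models leave added_cond_kwargs empty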
                                if self.sd.is_xl:
                                    added_cond_kwargs["text_embeds"] = conditional_embeds.pooled_embeds
                                    added_cond_kwargs['time_ids'] = self.sd.get_time_ids_from_latents(noisy_latents)
                                down_block_res_samples, mid_block_res_sample = adapter(
                                    noisy_latents,
                                    timesteps,
                                    encoder_hidden_states=conditional_embeds.text_embeds,
                                    controlnet_cond=adapter_images,
                                    conditioning_scale=1.0,
                                    guess_mode=False,
                                    added_cond_kwargs=added_cond_kwargs,
                                    return_dict=False,
                                )
                                pred_kwargs['down_block_additional_residuals'] = down_block_res_samples
                                pred_kwargs['mid_block_additional_residual'] = mid_block_res_sample

                self.before_unet_predict()

                # do a prior pred if we have an unconditional image; we will swap out the guidance later
                if batch.unconditional_latents is not None or self.do_guided_loss:
                    # do guided loss
                    loss = self.get_guided_loss(
                        noisy_latents=noisy_latents,
                        conditional_embeds=conditional_embeds,
                        match_adapter_assist=match_adapter_assist,
                        network_weight_list=network_weight_list,
                        timesteps=timesteps,
                        pred_kwargs=pred_kwargs,
                        batch=batch,
                        noise=noise,
                        unconditional_embeds=unconditional_embeds,
                        mask_multiplier=mask_multiplier,
                        prior_pred=prior_pred,
                    )
                else:
                    with self.timer('predict_unet'):
                        if unconditional_embeds is not None:
                            unconditional_embeds = unconditional_embeds.to(self.device_torch, dtype=dtype).detach()
                        noise_pred = self.predict_noise(
                            noisy_latents=noisy_latents.to(self.device_torch, dtype=dtype),
                            timesteps=timesteps,
                            conditional_embeds=conditional_embeds.to(self.device_torch, dtype=dtype),
                            unconditional_embeds=unconditional_embeds,
                            **pred_kwargs
                        )
                    self.after_unet_predict()

                    with self.timer('calculate_loss'):
                        noise = noise.to(self.device_torch, dtype=dtype).detach()
                        loss = self.calculate_loss(
                            noise_pred=noise_pred,
                            noise=noise,
                            noisy_latents=noisy_latents,
                            timesteps=timesteps,
                            batch=batch,
                            mask_multiplier=mask_multiplier,
                            prior_pred=prior_pred,
                        )

        loss = loss * loss_multiplier.mean()
        return loss
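

# Usage sketch (not from the original source): the surrounding trainer is
# assumed to build the StableDiffusion wrapper, configs, and a dataloader that
# yields DataLoaderBatchDTO batches. Every name below that is not defined in
# this file (sd, network, train_config, timer, data_loader, optimizer) is a
# placeholder. Since is_primary_gpu() keys off torch.cuda.current_device(),
# the module is presumably replicated with nn.DataParallel, e.g.:
#
#     model = UnifiedTrainingModel(
#         sd=sd,
#         network=network,            # or None
#         train_config=train_config,
#         timer=timer,
#         gpu_ids=[0, 1],
#     )
#     model = nn.DataParallel(model, device_ids=[0, 1])
#     for batch in data_loader:
#         loss = model(batch).mean()  # reduce per-replica losses to a scalar
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()
#
# Caveat: nn.DataParallel only scatters tensor arguments; a DataLoaderBatchDTO
# would be sent to every replica as-is, so the real trainer likely shards
# batches itself before calling forward().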