diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 9d3aeec3..f798cbc8 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -29,6 +29,7 @@ import gc
 import torch
 from jobs.process import BaseSDTrainProcess
 from torchvision import transforms
+import math
@@ -366,7 +367,29 @@ class SDTrainer(BaseSDTrainProcess):
                 loss_per_element = (weighing.float() * (denoised_latents.float() - target.float()) ** 2)
                 loss = loss_per_element
             else:
-                if self.train_config.loss_type == "mae":
+                # handle flow matching ref https://github.com/huggingface/diffusers/blob/ec068f9b5bf7c65f93125ec889e0ff1792a00da1/examples/dreambooth/train_dreambooth_lora_sd3.py#L1485C17-L1495C100
+                if self.sd.is_v3:
+                    target = noisy_latents.detach()
+                    bsz = pred.shape[0]
+                    # todo implement others
+                    # weighing_scheme =
+                    # 3 just do mode for now?
+                    # if args.weighting_scheme == "sigma_sqrt":
+                    sigmas = self.sd.noise_scheduler.get_sigmas(timesteps, pred.ndim, dtype, self.device_torch)
+                    weighting = (sigmas ** -2.0).float()
+                    # elif args.weighting_scheme == "logit_normal":
+                    #     # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
+                    #     u = torch.normal(mean=args.logit_mean, std=args.logit_std, size=(bsz,), device=accelerator.device)
+                    #     weighting = torch.nn.functional.sigmoid(u)
+                    # elif args.weighting_scheme == "mode":
+                    #     mode_scale = 1.29
+                    #     # See sec 3.1 in the SD3 paper (20).
+                    #     u = torch.rand(size=(bsz,), device=pred.device)
+                    #     weighting = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
+
+                    loss = (weighting.float() * (pred.float() - target.float()) ** 2).reshape(target.shape[0], -1)
+
+                elif self.train_config.loss_type == "mae":
                     loss = torch.nn.functional.l1_loss(pred.float(), target.float(), reduction="none")
                 else:
                     loss = torch.nn.functional.mse_loss(pred.float(), target.float(), reduction="none")
diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index 34f5d09f..c538a215 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -1244,6 +1244,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 conv_alpha=self.network_config.conv_alpha,
                 is_sdxl=self.model_config.is_xl or self.model_config.is_ssd,
                 is_v2=self.model_config.is_v2,
+                is_v3=self.model_config.is_v3,
                 is_ssd=self.model_config.is_ssd,
                 is_vega=self.model_config.is_vega,
                 dropout=self.network_config.dropout,
diff --git a/requirements.txt b/requirements.txt
index 19cdbcf4..ce63f11b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 torch
 torchvision
 safetensors
-diffusers==0.26.3
+diffusers
 transformers
 lycoris-lora==1.8.3
 flatten_json
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index bd948913..75c4d753 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -336,6 +336,7 @@ class ModelConfig:
         self.is_xl: bool = kwargs.get('is_xl', False)
         self.is_pixart: bool = kwargs.get('is_pixart', False)
         self.is_pixart_sigma: bool = kwargs.get('is_pixart', False)
+        self.is_v3: bool = kwargs.get('is_v3', False)
         if self.is_pixart_sigma:
             self.is_pixart = True
         self.is_ssd: bool = kwargs.get('is_ssd', False)
diff --git a/toolkit/dataloader_mixins.py b/toolkit/dataloader_mixins.py
index 045107a5..ac96a3a2 100644
--- a/toolkit/dataloader_mixins.py
+++ b/toolkit/dataloader_mixins.py
@@ -1353,6 +1353,8 @@ class LatentCachingMixin:
                 file_item.latent_space_version = self.sd.model_config.latent_space_version
             elif self.sd.is_xl:
                 file_item.latent_space_version = 'sdxl'
+            elif self.sd.is_v3:
+                file_item.latent_space_version = 'sd3'
             else:
                 file_item.latent_space_version = 'sd1'
             file_item.is_caching_to_disk = to_disk
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index 409fef2d..af917026 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -152,6 +152,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             train_unet: Optional[bool] = True,
             is_sdxl=False,
             is_v2=False,
+            is_v3=False,
             is_pixart: bool = False,
             use_bias: bool = False,
             is_lorm: bool = False,
@@ -200,6 +201,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.multiplier = multiplier
         self.is_sdxl = is_sdxl
         self.is_v2 = is_v2
+        self.is_v3 = is_v3
         self.is_pixart = is_pixart
         self.network_type = network_type
         if self.network_type.lower() == "dora":
@@ -233,7 +235,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 target_replace_modules: List[torch.nn.Module],
         ) -> List[LoRAModule]:
             unet_prefix = self.LORA_PREFIX_UNET
-            if is_pixart:
+            if is_pixart or is_v3:
                 unet_prefix = f"lora_transformer"
 
             prefix = (
@@ -346,6 +348,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         if modules_dim is not None or self.conv_lora_dim is not None or conv_block_dims is not None:
             target_modules += target_conv_modules
 
+        if is_v3:
+            target_modules = ["SD3Transformer2DModel"]
+
         if train_unet:
             self.unet_loras, skipped_un = create_modules(True, None, unet, target_modules)
         else:
diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index 93478acc..97914990 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -48,7 +48,7 @@ class LoRAGenerator(torch.nn.Module):
             head_size: int = 512,
             num_mlp_layers: int = 1,
             output_size: int = 768,
-            dropout: float = 0.5
+            dropout: float = 0.0
     ):
         super().__init__()
         self.input_size = input_size
@@ -131,8 +131,12 @@ class InstantLoRAMidModule(torch.nn.Module):
             x_chunk = x_chunks[i]
             # reshape
             weight_chunk = weight_chunk.view(self.down_shape)
-            # run a simple lenear layer with the down weight
-            x_chunk = x_chunk @ weight_chunk.T
+            # check if is conv or linear
+            if len(weight_chunk.shape) == 4:
+                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
+            else:
+                # run a simple linear layer with the down weight
+                x_chunk = x_chunk @ weight_chunk.T
             x_out.append(x_chunk)
         x = torch.cat(x_out, dim=0)
         return x
@@ -158,8 +162,12 @@ class InstantLoRAMidModule(torch.nn.Module):
             x_chunk = x_chunks[i]
             # reshape
             weight_chunk = weight_chunk.view(self.up_shape)
-            # run a simple lenear layer with the down weight
-            x_chunk = x_chunk @ weight_chunk.T
+            # check if is conv or linear
+            if len(weight_chunk.shape) == 4:
+                x_chunk = nn.functional.conv2d(x_chunk, weight_chunk)
+            else:
+                # run a simple linear layer with the down weight
+                x_chunk = x_chunk @ weight_chunk.T
             x_out.append(x_chunk)
         x = torch.cat(x_out, dim=0)
         return x
diff --git a/toolkit/reference_adapter.py b/toolkit/reference_adapter.py
index 90995ffc..d00dfb72 100644
--- a/toolkit/reference_adapter.py
+++ b/toolkit/reference_adapter.py
@@ -4,7 +4,6 @@ import torch
 import sys
 
 from PIL import Image
-from diffusers.models.unet_2d_condition import UNet2DConditionOutput
 from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
diff --git a/toolkit/sampler.py b/toolkit/sampler.py
index dae8b99b..e6c6e32e 100644
--- a/toolkit/sampler.py
+++ b/toolkit/sampler.py
@@ -13,9 +13,12 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
-    LCMScheduler
+    LCMScheduler,
+    FlowMatchEulerDiscreteScheduler,
 )
+from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler
+
 from k_diffusion.external import CompVisDenoiser
 
 from toolkit.samplers.custom_lcm_scheduler import CustomLCMScheduler
@@ -112,6 +115,15 @@ def get_sampler(
         scheduler_cls = LCMScheduler
     elif sampler == "custom_lcm":
         scheduler_cls = CustomLCMScheduler
+    elif sampler == "flowmatch":
+        scheduler_cls = CustomFlowMatchEulerDiscreteScheduler
+        config_to_use = {
+            "_class_name": "FlowMatchEulerDiscreteScheduler",
+            "_diffusers_version": "0.29.0.dev0",
+            "num_train_timesteps": 1000,
+            "shift": 3.0
+        }
+
     config = copy.deepcopy(config_to_use)
     config.update(sched_init_args)
diff --git a/toolkit/samplers/custom_flowmatch_sampler.py b/toolkit/samplers/custom_flowmatch_sampler.py
new file mode 100644
index 00000000..1d5750ad
--- /dev/null
+++ b/toolkit/samplers/custom_flowmatch_sampler.py
@@ -0,0 +1,32 @@
+from typing import Union
+
+from diffusers import FlowMatchEulerDiscreteScheduler
+import torch
+
+class CustomFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+
+    def get_sigmas(self, timesteps: torch.Tensor, n_dim, dtype, device) -> torch.Tensor:
+        sigmas = self.sigmas.to(device=device, dtype=dtype)
+        schedule_timesteps = self.timesteps.to(device)
+        timesteps = timesteps.to(device)
+        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+
+        sigma = sigmas[step_indices].flatten()
+        while len(sigma.shape) < n_dim:
+            sigma = sigma.unsqueeze(-1)
+
+        return sigma
+
+    def add_noise(
+        self,
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
+        n_dim = original_samples.ndim
+        sigmas = self.get_sigmas(timesteps, n_dim, original_samples.dtype, original_samples.device)
+        noisy_model_input = sigmas * noise + (1.0 - sigmas) * original_samples
+        return noisy_model_input
+
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
+        return sample
\ No newline at end of file
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 3990c13b..d792d359 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -40,13 +40,13 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
 from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
     StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
     StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny, ControlNetModel, \
-    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline
+    StableDiffusionXLControlNetPipeline, StableDiffusionControlNetPipeline, StableDiffusion3Pipeline, StableDiffusion3Img2ImgPipeline
 import diffusers
 from diffusers import \
     AutoencoderKL, \
     UNet2DConditionModel
 from diffusers import PixArtAlphaPipeline, DPMSolverMultistepScheduler
-from transformers import T5EncoderModel
+from transformers import T5EncoderModel, BitsAndBytesConfig
 from toolkit.util.pixart_sigma_patch import pixart_sigma_init_patched_inputs, PixArtSigmaPipeline
 from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
@@ -147,6 +147,7 @@ class StableDiffusion:
         self.is_xl = model_config.is_xl
         self.is_v2 = model_config.is_v2
         self.is_ssd = model_config.is_ssd
+        self.is_v3 = model_config.is_v3
         self.is_vega = model_config.is_vega
         self.is_pixart = model_config.is_pixart
@@ -236,6 +237,64 @@ class StableDiffusion:
                 te1_state_dict['text_projection.weight'] = replacement_weight.to(self.device_torch, dtype=dtype)
                 flush()
                 print("Injecting alt weights")
+        elif self.model_config.is_v3:
+            if self.custom_pipeline is not None:
+                pipln = self.custom_pipeline
+            else:
+                pipln = StableDiffusion3Pipeline
+
+            quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
+            model_id = "stabilityai/stable-diffusion-3-medium"
+            text_encoder3 = T5EncoderModel.from_pretrained(
+                model_id,
+                subfolder="text_encoder_3",
+                # quantization_config=quantization_config,
+                revision="refs/pr/26",
+                device_map="cuda"
+            )
+
+            # see if path exists
+            if not os.path.exists(model_path) or os.path.isdir(model_path):
+                try:
+                    # try to load with default diffusers
+                    pipe = pipln.from_pretrained(
+                        model_path,
+                        dtype=dtype,
+                        device=self.device_torch,
+                        text_encoder_3=text_encoder3,
+                        # variant="fp16",
+                        use_safetensors=True,
+                        revision="refs/pr/26",
+                        repo_type="model",
+                        ignore_patterns=["*.md", "*..gitattributes"],
+                        **load_args
+                    )
+                except Exception as e:
+                    print(f"Error loading from pretrained: {e}")
+                    raise e
+
+            else:
+                pipe = pipln.from_single_file(
+                    model_path,
+                    device=self.device_torch,
+                    torch_dtype=self.torch_dtype,
+                    text_encoder_3=text_encoder3,
+                )
+
+            flush()
+
+            text_encoders = [pipe.text_encoder, pipe.text_encoder_2, pipe.text_encoder_3]
+            tokenizer = [pipe.tokenizer, pipe.tokenizer_2, pipe.tokenizer_3]
+            # replace the to function with a no-op since it throws an error instead of a warning
+            # text_encoders[2].to = lambda *args, **kwargs: None
+            for text_encoder in text_encoders:
+                text_encoder.to(self.device_torch, dtype=dtype)
+                text_encoder.requires_grad_(False)
+                text_encoder.eval()
+            text_encoder = text_encoders
+
+
         elif self.model_config.is_pixart:
             te_kwargs = {}
             # handle quantization of TE
@@ -361,8 +420,8 @@ class StableDiffusion:
         # add hacks to unet to help training
         # pipe.unet = prepare_unet_for_training(pipe.unet)
 
-        if self.is_pixart:
-            # pixart doesnt use a unet
+        if self.is_pixart or self.is_v3:
+            # pixart and sd3 dont use a unet
             self.unet = pipe.transformer
         else:
             self.unet: 'UNet2DConditionModel' = pipe.unet
@@ -487,6 +546,8 @@ class StableDiffusion:
             Pipe = StableDiffusionKDiffusionXLPipeline
         elif self.is_xl:
             Pipe = StableDiffusionXLPipeline
+        elif self.is_v3:
+            Pipe = StableDiffusion3Pipeline
         else:
             Pipe = StableDiffusionPipeline
@@ -515,15 +576,30 @@ class StableDiffusion:
         if self.is_xl:
             pipeline = Pipe(
                 vae=self.vae,
-                unet=self.unet,
+                transformer=self.unet,
                 text_encoder=self.text_encoder[0],
                 text_encoder_2=self.text_encoder[1],
+                text_encoder_3=self.text_encoder[2],
                 tokenizer=self.tokenizer[0],
                 tokenizer_2=self.tokenizer[1],
+                tokenizer_3=self.tokenizer[2],
                 scheduler=noise_scheduler,
                 **extra_args
             ).to(self.device_torch)
             pipeline.watermark = None
+        elif self.is_v3:
+            pipeline = Pipe(
+                vae=self.vae,
+                transformer=self.unet,
+                text_encoder=self.text_encoder[0],
+                text_encoder_2=self.text_encoder[1],
+                text_encoder_3=self.text_encoder[2],
+                tokenizer=self.tokenizer[0],
+                tokenizer_2=self.tokenizer[1],
+                tokenizer_3=self.tokenizer[2],
+                scheduler=noise_scheduler,
+                **extra_args
+            )
         elif self.is_pixart:
             pipeline = PixArtAlphaPipeline(
                 vae=self.vae,
@@ -576,7 +652,7 @@ class StableDiffusion:
         if self.network is not None:
             start_multiplier = self.network.multiplier
 
-        pipeline.to(self.device_torch)
+        # pipeline.to(self.device_torch)
 
         with network:
             with torch.no_grad():
@@ -744,6 +820,19 @@ class StableDiffusion:
                     latents=gen_config.latents,
                     **extra
                 ).images[0]
+            elif self.is_v3:
+                img = pipeline(
+                    prompt_embeds=conditional_embeds.text_embeds,
+                    pooled_prompt_embeds=conditional_embeds.pooled_embeds,
+                    negative_prompt_embeds=unconditional_embeds.text_embeds,
+                    negative_pooled_prompt_embeds=unconditional_embeds.pooled_embeds,
+                    height=gen_config.height,
+                    width=gen_config.width,
+                    num_inference_steps=gen_config.num_inference_steps,
+                    guidance_scale=gen_config.guidance_scale,
+                    latents=gen_config.latents,
+                    **extra
+                ).images[0]
             elif self.is_pixart:
                 # needs attention masks for some reason
                 img = pipeline(
@@ -1004,6 +1093,20 @@ class StableDiffusion:
             )
             return torch.cat(out_chunks, dim=0)
 
+        def precondition_model_outputs_sd3(model_output, model_input, timestep_tensor):
+            mo_chunks = torch.chunk(model_output, model_output.shape[0], dim=0)
+            mi_chunks = torch.chunk(model_input, model_input.shape[0], dim=0)
+            timestep_chunks = torch.chunk(timestep_tensor, timestep_tensor.shape[0], dim=0)
+            out_chunks = []
+            # unsqueeze if timestep is zero dim
+            for idx in range(model_output.shape[0]):
+                sigmas = self.noise_scheduler.get_sigmas(timestep_chunks[idx], n_dim=model_output.ndim, dtype=model_output.dtype, device=model_output.device)
+                # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
+                # Preconditioning of the model outputs.
+                out = mo_chunks[idx] * (-sigmas) + mi_chunks[idx]
+                out_chunks.append(out)
+            return torch.cat(out_chunks, dim=0)
+
         if self.is_xl:
             with torch.no_grad():
                 # 16, 6 for bs of 4
@@ -1177,12 +1280,22 @@ class StableDiffusion:
                 self.unet.to(self.device_torch)
             if self.unet.dtype != self.torch_dtype:
                 self.unet = self.unet.to(dtype=self.torch_dtype)
-            noise_pred = self.unet(
-                latent_model_input.to(self.device_torch, self.torch_dtype),
-                timestep,
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
-                **kwargs,
-            ).sample
+            if self.is_v3:
+                noise_pred = self.unet(
+                    hidden_states=latent_model_input.to(self.device_torch, self.torch_dtype),
+                    timestep=timestep,
+                    encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+                    pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, self.torch_dtype),
+                    **kwargs,
+                ).sample
+                noise_pred = precondition_model_outputs_sd3(noise_pred, latent_model_input, timestep)
+            else:
+                noise_pred = self.unet(
+                    latent_model_input.to(self.device_torch, self.torch_dtype),
+                    timestep=timestep,
+                    encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, self.torch_dtype),
+                    **kwargs,
+                ).sample
 
             conditional_pred = noise_pred
@@ -1343,6 +1456,19 @@ class StableDiffusion:
                     dropout_prob=dropout_prob,
                 )
             )
+        if self.is_v3:
+            return PromptEmbeds(
+                train_tools.encode_prompts_sd3(
+                    self.tokenizer,
+                    self.text_encoder,
+                    prompt,
+                    num_images_per_prompt=num_images_per_prompt,
+                    truncate=not long_prompts,
+                    max_length=max_length,
+                    dropout_prob=dropout_prob,
+                    pipeline=self.pipeline,
+                )
+            )
         elif self.is_pixart:
             embeds, attention_mask = train_tools.encode_prompts_pixart(
                 self.tokenizer,
@@ -1735,7 +1861,7 @@ class StableDiffusion:
     def save_device_state(self):
         # saves the current device state for all modules
         # this is useful for when we want to alter the state and restore it
-        if self.is_pixart:
+        if self.is_pixart or self.is_v3:
             unet_has_grad = self.unet.proj_out.weight.requires_grad
         else:
             unet_has_grad = self.unet.conv_in.weight.requires_grad
@@ -1755,11 +1881,15 @@ class StableDiffusion:
         if isinstance(self.text_encoder, list):
             self.device_state['text_encoder']: List[dict] = []
             for encoder in self.text_encoder:
+                try:
+                    te_has_grad = encoder.text_model.final_layer_norm.weight.requires_grad
+                except:
+                    te_has_grad = encoder.encoder.block[0].layer[0].SelfAttention.q.weight.requires_grad
                 self.device_state['text_encoder'].append({
                     'training': encoder.training,
                     'device': encoder.device,
                     # todo there has to be a better way to do this
-                    'requires_grad': encoder.text_model.final_layer_norm.weight.requires_grad
+                    'requires_grad': te_has_grad
                 })
         else:
             if isinstance(self.text_encoder, T5EncoderModel):
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py
index 1bae37b5..db15e63f 100644
--- a/toolkit/train_tools.py
+++ b/toolkit/train_tools.py
@@ -25,6 +25,7 @@ from diffusers import (
     HeunDiscreteScheduler,
     KDPM2DiscreteScheduler,
     KDPM2AncestralDiscreteScheduler,
+    StableDiffusion3Pipeline
 )
 from library.lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline
 import torch
@@ -580,6 +581,58 @@ def encode_prompts_xl(
     return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
 
 
+def encode_prompts_sd3(
+        tokenizers: list['CLIPTokenizer'],
+        text_encoders: list[Union['CLIPTextModel', 'CLIPTextModelWithProjection', T5EncoderModel]],
+        prompts: list[str],
+        num_images_per_prompt: int = 1,
+        truncate: bool = True,
+        max_length=None,
+        dropout_prob=0.0,
+        pipeline: StableDiffusion3Pipeline = None,
+):
+    text_embeds_list = []
+    pooled_text_embeds = None  # always text_encoder_2's pool
+
+    prompt_2 = prompts
+    prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+    prompt_3 = prompts
+    prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+    device = text_encoders[0].device
+
+    prompt_embed, pooled_prompt_embed = pipeline._get_clip_prompt_embeds(
+        prompt=prompts,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        clip_skip=None,
+        clip_model_index=0,
+    )
+    prompt_2_embed, pooled_prompt_2_embed = pipeline._get_clip_prompt_embeds(
+        prompt=prompt_2,
+        device=device,
+        num_images_per_prompt=num_images_per_prompt,
+        clip_skip=None,
+        clip_model_index=1,
+    )
+    clip_prompt_embeds = torch.cat([prompt_embed, prompt_2_embed], dim=-1)
+
+    t5_prompt_embed = pipeline._get_t5_prompt_embeds(
+        prompt=prompt_3,
+        num_images_per_prompt=num_images_per_prompt,
+        device=device
+    )
+
+    clip_prompt_embeds = torch.nn.functional.pad(
+        clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1])
+    )
+
+    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
+    pooled_prompt_embeds = torch.cat([pooled_prompt_embed, pooled_prompt_2_embed], dim=-1)
+
+    return prompt_embeds, pooled_prompt_embeds
+
 # ref for long prompts https://github.com/huggingface/diffusers/issues/2136
 def text_encode(text_encoder: 'CLIPTextModel', tokens, truncate: bool = True, max_length=None):
@@ -720,18 +773,22 @@ def concat_embeddings(
 
 
 def add_all_snr_to_noise_scheduler(noise_scheduler, device):
-    if hasattr(noise_scheduler, "all_snr"):
-        return
-    # compute it
-    with torch.no_grad():
-        alphas_cumprod = noise_scheduler.alphas_cumprod
-        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
-        sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
+    try:
+        if hasattr(noise_scheduler, "all_snr"):
+            return
+        # compute it
+        with torch.no_grad():
+            alphas_cumprod = noise_scheduler.alphas_cumprod
+            sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
+            sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
-        alpha = sqrt_alphas_cumprod
-        sigma = sqrt_one_minus_alphas_cumprod
-        all_snr = (alpha / sigma) ** 2
-        all_snr.requires_grad = False
-        noise_scheduler.all_snr = all_snr.to(device)
+            alpha = sqrt_alphas_cumprod
+            sigma = sqrt_one_minus_alphas_cumprod
+            all_snr = (alpha / sigma) ** 2
+            all_snr.requires_grad = False
+            noise_scheduler.all_snr = all_snr.to(device)
+    except Exception as e:
+        print(e)
+        print("Failed to add all_snr to noise_scheduler")
 
 
 def get_all_snr(noise_scheduler, device):
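
Note (reviewer sketch, not part of the patch): the new pieces above are meant to be used together at training time — the "flowmatch" sampler config builds a CustomFlowMatchEulerDiscreteScheduler, its add_noise() does the flow-matching interpolation, and the SD3 branch in SDTrainer weights the squared error by sigma ** -2. The sketch below is illustrative only; it assumes diffusers >= 0.29 (where FlowMatchEulerDiscreteScheduler exists) and uses hypothetical tensor shapes and names (bsz, latents).

    # Illustrative sketch only -- not part of the patch. Assumes diffusers >= 0.29.
    import torch
    from toolkit.samplers.custom_flowmatch_sampler import CustomFlowMatchEulerDiscreteScheduler

    # same config the "flowmatch" branch in toolkit/sampler.py passes in
    scheduler = CustomFlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)

    bsz = 4
    latents = torch.randn(bsz, 16, 64, 64)   # clean SD3 latents (illustrative shape)
    noise = torch.randn_like(latents)

    # draw training timesteps from the scheduler's own schedule so get_sigmas()
    # can match them exactly against scheduler.timesteps
    idx = torch.randint(0, scheduler.config.num_train_timesteps, (bsz,))
    timesteps = scheduler.timesteps[idx]

    # add_noise() is the flow-matching interpolation: sigma * noise + (1 - sigma) * latents
    noisy_latents = scheduler.add_noise(latents, noise, timesteps)

    # get_sigmas() broadcasts sigma to the latent rank; the SD3 loss branch in SDTrainer
    # uses sigma ** -2 as the per-sample weighting (the "sigma_sqrt" scheme in the
    # referenced diffusers dreambooth SD3 script)
    sigmas = scheduler.get_sigmas(timesteps, latents.ndim, latents.dtype, latents.device)
    weighting = (sigmas ** -2.0).float()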