From 7fed4ea7615c165d875c9a5b6ea80fb827e5af01 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 14 Aug 2024 10:14:13 -0600
Subject: [PATCH] fixed huge flux training bug. Added ability to use an assistant lora

---
 toolkit/assistant_lora.py         |  55 ++++++++++++++
 toolkit/lora_special.py           |   2 +
 toolkit/network_mixins.py         |   2 +-
 toolkit/stable_diffusion_model.py | 114 +++++++++++++++++-------------
 4 files changed, 124 insertions(+), 49 deletions(-)
 create mode 100644 toolkit/assistant_lora.py

diff --git a/toolkit/assistant_lora.py b/toolkit/assistant_lora.py
new file mode 100644
index 00000000..cdeca968
--- /dev/null
+++ b/toolkit/assistant_lora.py
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING
+from toolkit.config_modules import NetworkConfig
+from toolkit.lora_special import LoRASpecialNetwork
+from safetensors.torch import load_file
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+
+
+def load_assistant_lora_from_path(adapter_path, sd: 'StableDiffusion') -> LoRASpecialNetwork:
+    if not sd.is_flux:
+        raise ValueError("Only Flux models can load assistant adapters currently.")
+    pipe = sd.pipeline
+    print(f"Loading assistant adapter from {adapter_path}")
+    adapter_name = adapter_path.split("/")[-1].split(".")[0]
+    lora_state_dict = load_file(adapter_path)
+
+    linear_dim = int(lora_state_dict['transformer.single_transformer_blocks.0.attn.to_k.lora_A.weight'].shape[0])
+    # linear_alpha = int(lora_state_dict['lora_transformer_single_transformer_blocks_0_attn_to_k.alpha'].item())
+    linear_alpha = linear_dim
+    transformer_only = 'transformer.proj_out.alpha' not in lora_state_dict
+    # get dim and scale
+    network_config = NetworkConfig(
+        linear=linear_dim,
+        linear_alpha=linear_alpha,
+        transformer_only=transformer_only,
+    )
+
+    network = LoRASpecialNetwork(
+        text_encoder=pipe.text_encoder,
+        unet=pipe.transformer,
+        lora_dim=network_config.linear,
+        multiplier=1.0,
+        alpha=network_config.linear_alpha,
+        train_unet=True,
+        train_text_encoder=False,
+        is_flux=True,
+        network_config=network_config,
+        network_type=network_config.type,
+        transformer_only=network_config.transformer_only,
+        is_assistant_adapter=True
+    )
+    network.apply_to(
+        pipe.text_encoder,
+        pipe.transformer,
+        apply_text_encoder=False,
+        apply_unet=True
+    )
+    network.force_to(sd.device_torch, dtype=sd.torch_dtype)
+    network.eval()
+    network._update_torch_multiplier()
+    network.load_weights(lora_state_dict)
+    network.is_active = True
+
+    return network
diff --git a/toolkit/lora_special.py b/toolkit/lora_special.py
index b407618a..80f139c3 100644
--- a/toolkit/lora_special.py
+++ b/toolkit/lora_special.py
@@ -175,6 +175,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             full_train_in_out: bool = False,
             transformer_only: bool = False,
             peft_format: bool = False,
+            is_assistant_adapter: bool = False,
             **kwargs
     ) -> None:
         """
@@ -223,6 +224,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
         self.is_auraflow = is_auraflow
         self.is_flux = is_flux
         self.network_type = network_type
+        self.is_assistant_adapter = is_assistant_adapter
         if self.network_type.lower() == "dora":
             self.module_class = DoRAModule
             module_class = DoRAModule
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index 620c9852..14ac6161 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -263,7 +263,7 @@ class ToolkitModuleMixin:
         if isinstance(x, QTensor):
            x = x.dequantize()
         # always cast to float32
-        lora_input = x.float()
+        lora_input = x.to(self.lora_down.weight.dtype)
         lora_output = self._call_forward(lora_input)
         multiplier = self.network_ref().torch_multiplier
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 0896080f..51724130 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -11,7 +11,8 @@ from collections import OrderedDict
 import copy
 import yaml
 from PIL import Image
-from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
+from diffusers.pipelines.pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_1024_BIN, ASPECT_RATIO_512_BIN, \
+    ASPECT_RATIO_2048_BIN, ASPECT_RATIO_256_BIN
 from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import rescale_noise_cfg
 from safetensors.torch import save_file, load_file
 from torch import autocast
@@ -20,6 +21,7 @@ from torch.utils.checkpoint import checkpoint
 from tqdm import tqdm
 from torchvision.transforms import Resize, transforms
 
+from toolkit.assistant_lora import load_assistant_lora_from_path
 from toolkit.clip_vision_adapter import ClipVisionAdapter
 from toolkit.custom_adapter import CustomAdapter
 from toolkit.ip_adapter import IPAdapter
@@ -57,6 +59,10 @@ from toolkit.paths import ORIG_CONFIGS_ROOT, DIFFUSERS_CONFIGS_ROOT
 from toolkit.util.inverse_cfg import inverse_classifier_guidance
 from optimum.quanto import freeze, qfloat8, quantize, QTensor, qint4
 
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from toolkit.lora_special import LoRASpecialNetwork
 
 # tell it to shut up
 diffusers.logging.set_verbosity(diffusers.logging.ERROR)
@@ -84,7 +90,6 @@ DO_NOT_TRAIN_WEIGHTS = [
 
 DeviceStatePreset = Literal['cache_latents', 'generate']
-
 
 class BlankNetwork:
 
     def __init__(self):
@@ -127,10 +132,12 @@ class StableDiffusion:
 
         self.torch_dtype = get_torch_dtype(dtype)
         self.device_torch = torch.device(self.device)
-        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(model_config.vae_device)
+        self.vae_device_torch = torch.device(self.device) if model_config.vae_device is None else torch.device(
+            model_config.vae_device)
         self.vae_torch_dtype = get_torch_dtype(model_config.vae_dtype)
-        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(model_config.te_device)
+        self.te_device_torch = torch.device(self.device) if model_config.te_device is None else torch.device(
+            model_config.te_device)
         self.te_torch_dtype = get_torch_dtype(model_config.te_dtype)
 
         self.model_config = model_config
@@ -146,6 +153,7 @@ class StableDiffusion:
 
         self.noise_scheduler: Union[None, 'DDPMScheduler'] = noise_scheduler
         self.refiner_unet: Union[None, 'UNet2DConditionModel'] = None
+        self.assistant_lora: Union[None, 'LoRASpecialNetwork'] = None
 
         # sdxl stuff
         self.logit_scale = None
@@ -270,7 +278,7 @@ class StableDiffusion:
         # see if path exists
         if not os.path.exists(model_path) or os.path.isdir(model_path):
             try:
-                # try to load with default diffusers
+                # try to load with default diffusers
                 pipe = pipln.from_pretrained(
                     model_path,
                     dtype=dtype,
@@ -377,7 +385,7 @@ class StableDiffusion:
                 device=self.device_torch,
                 **load_args
             ).to(self.device_torch)
-            
+
             if self.model_config.unet_sample_size is not None:
                 pipe.transformer.config.sample_size = self.model_config.unet_sample_size
             pipe.transformer = pipe.transformer.to(self.device_torch, dtype=dtype)
@@ -462,6 +470,8 @@ class StableDiffusion:
             print("Loading transformer")
             subfolder = 'transformer'
             transformer_path = model_path
+            local_files_only = False
+            # check if HF_DATASETS_OFFLINE or TRANSFORMERS_OFFLINE is set
             if os.path.exists(transformer_path):
                 subfolder = None
                 transformer_path = os.path.join(transformer_path, 'transformer')
@@ -518,7 +528,8 @@
             print("Loading t5")
             tokenizer_2 = T5TokenizerFast.from_pretrained(base_model_path, subfolder="tokenizer_2", torch_dtype=dtype)
-            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2", torch_dtype=dtype)
+            text_encoder_2 = T5EncoderModel.from_pretrained(base_model_path, subfolder="text_encoder_2",
+                                                            torch_dtype=dtype)
             text_encoder_2.to(self.device_torch, dtype=dtype)
             flush()
@@ -655,21 +666,17 @@
             # unfortunately, not an easier way with peft
             pipe.unload_lora_weights()
 
-        if self.model_config.assistant_lora_path is not None:
-            if self.model_config.lora_path is not None:
-                raise ValueError("Cannot have both lora and assistant lora")
-            print("Loading assistant lora")
-            pipe.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            pipe.fuse_lora(lora_scale=1.0)
-            # unfortunately, not an easier way with peft
-            pipe.unload_lora_weights()
-
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.pipeline = pipe
         self.load_refiner()
         self.is_loaded = True
 
+        if self.model_config.assistant_lora_path is not None:
+            print("Loading assistant lora")
+            self.assistant_lora: 'LoRASpecialNetwork' = load_assistant_lora_from_path(
+                self.model_config.assistant_lora_path, self)
+
         if self.is_pixart and self.vae_scale_factor == 16:
             # TODO make our own pipeline?
             # we generate an image 2x larger, so we need to copy the sizes from larger ones down
@@ -741,9 +748,7 @@
         if self.model_config.assistant_lora_path is not None:
             print("Unloading asistant lora")
             # unfortunately, not an easier way with peft
-            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            self.pipeline.fuse_lora(lora_scale=-1.0)
-            self.pipeline.unload_lora_weights()
+            self.assistant_lora.is_active = False
 
         if self.network is not None:
             self.network.eval()
@@ -1027,7 +1032,6 @@
                 if self.adapter is not None and isinstance(self.adapter, IPAdapter) \
                         and gen_config.adapter_image_path is not None:
-
                     # apply the image projection
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
                     unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image,
@@ -1035,7 +1039,8 @@
                     conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
                     unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
 
-                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None:
+                if self.adapter is not None and isinstance(self.adapter,
+                                                           CustomAdapter) and validation_image is not None:
                     conditional_embeds = self.adapter.condition_encoded_embeds(
                         tensors_0_1=validation_image,
                         prompt_embeds=conditional_embeds,
@@ -1052,13 +1057,14 @@
                         is_generating_samples=True,
                     )
 
-                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(gen_config.extra_values) > 0:
-                    extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch, dtype=self.torch_dtype)
+                if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and len(
+                        gen_config.extra_values) > 0:
+                    extra_values = torch.tensor([gen_config.extra_values], device=self.device_torch,
+                                                dtype=self.torch_dtype)
                     # apply extra values to the embeddings
                     self.adapter.add_extra_values(extra_values, is_unconditional=False)
                     self.adapter.add_extra_values(torch.zeros_like(extra_values), is_unconditional=True)
-                    pass  # todo remove, for debugging
-
+                    pass  # todo remove, for debugging
 
                 if self.refiner_unet is not None and gen_config.refiner_start_at < 1.0:
                     # if we have a refiner loaded, set the denoising end at the refiner start
@@ -1148,9 +1154,12 @@
                     img = pipeline(
                         prompt=None,
                         prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                        prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
-                        negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                        negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
+                        prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+                                                                                   dtype=self.unet.dtype),
+                        negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+                                                                                   dtype=self.unet.dtype),
+                        negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+                                                                                              dtype=self.unet.dtype),
                         negative_prompt=None,
                         # negative_prompt=gen_config.negative_prompt,
                         height=gen_config.height,
@@ -1166,9 +1175,12 @@
                     img = pipeline(
                         prompt=None,
                         prompt_embeds=conditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                        prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
-                        negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch, dtype=self.unet.dtype),
-                        negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch, dtype=self.unet.dtype),
+                        prompt_attention_mask=conditional_embeds.attention_mask.to(self.device_torch,
+                                                                                   dtype=self.unet.dtype),
+                        negative_prompt_embeds=unconditional_embeds.text_embeds.to(self.device_torch,
+                                                                                   dtype=self.unet.dtype),
+                        negative_prompt_attention_mask=unconditional_embeds.attention_mask.to(self.device_torch,
+                                                                                              dtype=self.unet.dtype),
                         negative_prompt=None,
                         # negative_prompt=gen_config.negative_prompt,
                         height=gen_config.height,
@@ -1247,9 +1259,7 @@
         if self.model_config.assistant_lora_path is not None:
             print("Loading asistant lora")
             # unfortunately, not an easier way with peft
-            self.pipeline.load_lora_weights(self.model_config.assistant_lora_path, adapter_name="assistant_lora")
-            self.pipeline.fuse_lora(lora_scale=1.0)
-            self.pipeline.unload_lora_weights()
+            self.assistant_lora.is_active = True
 
     def get_latent_noise(
             self,
@@ -1332,7 +1342,8 @@
         noisy_latents_chunks = []
         for idx in range(original_samples.shape[0]):
-            noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx], timesteps_chunks[idx])
+            noisy_latents = self.noise_scheduler.add_noise(original_samples_chunks[idx], noise_chunks[idx],
+                                                           timesteps_chunks[idx])
             noisy_latents_chunks.append(noisy_latents)
 
         noisy_latents = torch.cat(noisy_latents_chunks, dim=0)
@@ -1392,7 +1403,6 @@
             else:
                 timestep = timestep.repeat(latents.shape[0], 0)
-
             # handle t2i adapters
             if 'down_intrablock_additional_residuals' in kwargs:
                 # go through each item and concat if doing cfg and it doesnt have the same shape
@@ -1561,7 +1571,6 @@
             height = h * VAE_SCALE_FACTOR
             width = w * VAE_SCALE_FACTOR
-
             if self.pipeline.transformer.config.sample_size == 256:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
             elif self.pipeline.transformer.config.sample_size == 128:
@@ -1573,10 +1582,12 @@
             else:
                 raise ValueError(f"Invalid sample size: {self.pipeline.transformer.config.sample_size}")
             orig_height, orig_width = height, width
-            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
+            height, width = self.pipeline.image_processor.classify_height_width_bin(height, width,
+                                                                                    ratios=aspect_ratio_bin)
 
             added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-            if self.unet.config.sample_size == 128 or (self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
+            if self.unet.config.sample_size == 128 or (
+                    self.vae_scale_factor == 16 and self.unet.config.sample_size == 64):
                 resolution = torch.tensor([height, width]).repeat(batch_size, 1)
                 aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size, 1)
                 resolution = resolution.to(dtype=text_embeddings.text_embeds.dtype, device=self.device_torch)
@@ -1641,7 +1652,8 @@
                 # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transforme rmodel (we should not keep it but I want to keep the inputs same for the model for testing)
                 # todo make sure this doesnt change
                 timestep=timestep / 1000,  # timestep is 1000 scale
-                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),  # [1, 512, 4096]
+                encoder_hidden_states=text_embeddings.text_embeds.to(self.device_torch, cast_dtype),
+                # [1, 512, 4096]
                 pooled_projections=text_embeddings.pooled_embeds.to(self.device_torch, cast_dtype),  # [1, 768]
                 txt_ids=txt_ids,  # [1, 512, 3]
                 img_ids=img_ids,  # [1, 4096, 3]
@@ -1705,7 +1717,7 @@
                 with torch.no_grad():
                     # do cfg at the target rescale so we can match it
                     target_pred_mean_std = noise_pred_uncond + rescale_cfg * (
-                            noise_pred_text - noise_pred_uncond
+                        noise_pred_text - noise_pred_uncond
                     )
                     target_mean = target_pred_mean_std.mean([1, 2, 3], keepdim=True).detach()
                     target_std = target_pred_mean_std.std([1, 2, 3], keepdim=True).detach()
@@ -1910,7 +1922,7 @@
                 self.text_encoder,
                 prompt,
                 truncate=not long_prompts,
-                max_length=77, # todo set this higher when not transfer learning
+                max_length=77,  # todo set this higher when not transfer learning
                 dropout_prob=dropout_prob
             )
             return PromptEmbeds(
@@ -1957,16 +1969,19 @@
             for i in range(len(image_list)):
                 image = image_list[i]
                 if image.shape[1] % VAE_SCALE_FACTOR != 0 or image.shape[2] % VAE_SCALE_FACTOR != 0:
-                    image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR, image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
+                    image_list[i] = Resize((image.shape[1] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR,
+                                            image.shape[2] // VAE_SCALE_FACTOR * VAE_SCALE_FACTOR))(image)
 
             images = torch.stack(image_list)
             if isinstance(self.vae, AutoencoderTiny):
                 latents = self.vae.encode(images, return_dict=False)[0]
             else:
                 latents = self.vae.encode(images).latent_dist.sample()
-                # latents = self.vae.encode(images, return_dict=False)[0]
             shift = self.vae.config['shift_factor'] if self.vae.config['shift_factor'] is not None else 0
-            latents = latents * (self.vae.config['scaling_factor'] - shift)
+
+            # flux ref https://github.com/black-forest-labs/flux/blob/c23ae247225daba30fbd56058d247cc1b1fc20a3/src/flux/modules/autoencoder.py#L303
+            # z = self.scale_factor * (z - self.shift_factor)
+            latents = self.vae.config['scaling_factor'] * (latents - shift)
             latents = latents.to(device, dtype=dtype)
 
             return latents
@@ -2107,12 +2122,15 @@ class StableDiffusion:
         # train the guidance embedding
         if self.unet.config.guidance_embeds:
             transformer: FluxTransformer2DModel = self.unet
-            for name, param in transformer.time_text_embed.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            for name, param in transformer.time_text_embed.named_parameters(recurse=True,
+                                                                            prefix=f"{SD_PREFIX_UNET}"):
                 named_params[name] = param
-            for name, param in self.unet.transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            for name, param in self.unet.transformer_blocks.named_parameters(recurse=True,
+                                                                             prefix=f"{SD_PREFIX_UNET}"):
                 named_params[name] = param
-            for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
+            for name, param in self.unet.single_transformer_blocks.named_parameters(recurse=True,
+                                                                                    prefix=f"{SD_PREFIX_UNET}"):
                 named_params[name] = param
         else:
             for name, param in self.unet.named_parameters(recurse=True, prefix=f"{SD_PREFIX_UNET}"):
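
Note on the core training fix above: encode_images previously normalized Flux latents as latents * (scaling_factor - shift), while the reference autoencoder shifts first and then scales, z = scale_factor * (z - shift_factor). A minimal standalone sketch of the corrected normalization and its algebraic inverse follows; the function and variable names here are illustrative only, not part of the toolkit's API.

    import torch

    def normalize_flux_latents(latents: torch.Tensor, scaling_factor: float, shift_factor: float) -> torch.Tensor:
        # matches the flux reference: z = scale_factor * (z - shift_factor)
        return scaling_factor * (latents - shift_factor)

    def denormalize_flux_latents(latents: torch.Tensor, scaling_factor: float, shift_factor: float) -> torch.Tensor:
        # inverse transform, applied before handing latents back to the VAE decoder
        return latents / scaling_factor + shift_factor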
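For the assistant adapter half of the change: when model_config.assistant_lora_path is set, the loader now builds a LoRASpecialNetwork through load_assistant_lora_from_path and simply flips its is_active flag around sampling, instead of fusing and unfusing weights through peft. A hedged usage sketch, assuming an already-constructed StableDiffusion instance named sd and an illustrative file path:

    from toolkit.assistant_lora import load_assistant_lora_from_path

    # sd wraps a loaded Flux pipeline; the helper raises ValueError for non-Flux models
    assistant = load_assistant_lora_from_path("/path/to/assistant_lora.safetensors", sd)

    assistant.is_active = False  # deactivate while generating validation samples
    # ... run sampling ...
    assistant.is_active = True   # reactivate for training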