From 88b3fbae379e4d9201d98acdd0739bd52d4a25dc Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Fri, 25 Apr 2025 13:44:38 -0600 Subject: [PATCH] Various experiments and minor bug fixes for edge cases --- jobs/process/BaseSDTrainProcess.py | 31 ++++-- toolkit/config_modules.py | 1 + toolkit/models/base_model.py | 4 +- toolkit/models/ilora2.py | 160 ++++++++++------------------- toolkit/models/llm_adapter.py | 4 +- toolkit/stable_diffusion_model.py | 5 +- toolkit/util/blended_blur_noise.py | 84 +++++++++++++++ toolkit/util/shuffle.py | 3 + 8 files changed, 170 insertions(+), 122 deletions(-) create mode 100644 toolkit/util/blended_blur_noise.py diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 08b6760f..83ed8242 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -69,6 +69,7 @@ import transformers import diffusers import hashlib +from toolkit.util.blended_blur_noise import get_blended_blur_noise from toolkit.util.get_model import get_model_class def flush(): @@ -903,7 +904,14 @@ class BaseSDTrainProcess(BaseTrainProcess): return noise - def get_noise(self, latents, batch_size, dtype=torch.float32, batch: 'DataLoaderBatchDTO' = None): + def get_noise( + self, + latents, + batch_size, + dtype=torch.float32, + batch: 'DataLoaderBatchDTO' = None, + timestep=None, + ): if self.train_config.optimal_noise_pairing_samples > 1: noise = self.get_optimal_noise(latents, dtype=dtype) elif self.train_config.force_consistent_noise: @@ -933,12 +941,11 @@ class BaseSDTrainProcess(BaseTrainProcess): # add to noise noise += noise_shift - - # standardize the noise - # shouldnt be needed? - # std = noise.std(dim=(2, 3), keepdim=True) - # normalizer = 1 / (std + 1e-6) - # noise = noise * normalizer + + if self.train_config.blended_blur_noise: + noise = get_blended_blur_noise( + latents, noise, timestep + ) return noise @@ -1193,7 +1200,7 @@ class BaseSDTrainProcess(BaseTrainProcess): timesteps = torch.stack(timesteps, dim=0) # get noise - noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch) + noise = self.get_noise(latents, batch_size, dtype=dtype, batch=batch, timestep=timesteps) # add dynamic noise offset. Dynamic noise is offsetting the noise to the same channelwise mean as the latents # this will negate any noise offsets @@ -1924,10 +1931,14 @@ class BaseSDTrainProcess(BaseTrainProcess): start_step_num = self.step_num did_first_flush = False + flush_next = False for step in range(start_step_num, self.train_config.steps): if self.train_config.do_paramiter_swapping: self.optimizer.optimizer.swap_paramiters() self.timer.start('train_loop') + if flush_next: + flush() + flush_next = False if self.train_config.do_random_cfg: self.train_config.do_cfg = True self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale) @@ -2089,6 +2100,10 @@ class BaseSDTrainProcess(BaseTrainProcess): print_acc(f"\nSaving at step {self.step_num}") self.save(self.step_num) self.ensure_params_requires_grad() + # clear any grads + optimizer.zero_grad() + flush() + flush_next = True if self.progress_bar is not None: self.progress_bar.unpause() diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 3268aa13..0f9ef88e 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -460,6 +460,7 @@ class TrainConfig: # forces same noise for the same image at a given size. self.force_consistent_noise = kwargs.get('force_consistent_noise', False) + self.blended_blur_noise = kwargs.get('blended_blur_noise', False) ModelArch = Literal['sd1', 'sd2', 'sd3', 'sdxl', 'pixart', 'pixart_sigma', 'auraflow', 'flux', 'flex2', 'lumina2', 'vega', 'ssd', 'wan21'] diff --git a/toolkit/models/base_model.py b/toolkit/models/base_model.py index b371604f..6982e506 100644 --- a/toolkit/models/base_model.py +++ b/toolkit/models/base_model.py @@ -1160,12 +1160,12 @@ class BaseModel: if self.model_config.ignore_if_contains is not None: # remove params that contain the ignore_if_contains from named params for key in list(named_params.keys()): - if any([s in key for s in self.model_config.ignore_if_contains]): + if any([s in f"transformer.{key}" for s in self.model_config.ignore_if_contains]): del named_params[key] if self.model_config.only_if_contains is not None: # remove params that do not contain the only_if_contains from named params for key in list(named_params.keys()): - if not any([s in key for s in self.model_config.only_if_contains]): + if not any([s in f"transformer.{key}" for s in self.model_config.only_if_contains]): del named_params[key] if refiner: diff --git a/toolkit/models/ilora2.py b/toolkit/models/ilora2.py index 2aba5eae..886d263c 100644 --- a/toolkit/models/ilora2.py +++ b/toolkit/models/ilora2.py @@ -1,7 +1,6 @@ import math import weakref -from toolkit.config_modules import AdapterConfig import torch import torch.nn as nn from typing import TYPE_CHECKING, List, Dict, Any @@ -35,7 +34,6 @@ class MLP(nn.Module): x = x + residual return x - class LoRAGenerator(torch.nn.Module): def __init__( self, @@ -60,8 +58,7 @@ class LoRAGenerator(torch.nn.Module): self.lin_in = nn.Linear(input_size, hidden_size) self.mlp_blocks = nn.Sequential(*[ - MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in - range(num_mlp_layers) + MLP(hidden_size, hidden_size, hidden_size, dropout=dropout, use_residual=True) for _ in range(num_mlp_layers) ]) self.head = nn.Linear(hidden_size, head_size, bias=False) self.norm = nn.LayerNorm(head_size) @@ -128,22 +125,15 @@ class InstantLoRAMidModule(torch.nn.Module): self.lora_module_ref = weakref.ref(lora_module) self.instant_lora_module_ref = weakref.ref(instant_lora_module) - self.do_up = instant_lora_module.config.ilora_up - self.do_down = instant_lora_module.config.ilora_down - self.do_mid = instant_lora_module.config.ilora_mid - - self.down_dim = self.down_shape[1] if self.do_down else 0 - self.mid_dim = self.up_shape[1] if self.do_mid else 0 - self.out_dim = self.up_shape[0] if self.do_up else 0 - self.embed = None def down_forward(self, x, *args, **kwargs): - if not self.do_down: - return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs) # get the embed self.embed = self.instant_lora_module_ref().img_embeds[self.index] - down_weight = self.embed[:, :self.down_dim] + if x.dtype != self.embed.dtype: + x = x.to(self.embed.dtype) + down_size = math.prod(self.down_shape) + down_weight = self.embed[:, :down_size] batch_size = x.shape[0] @@ -151,72 +141,7 @@ class InstantLoRAMidModule(torch.nn.Module): if down_weight.shape[0] * 2 == batch_size: down_weight = torch.cat([down_weight] * 2, dim=0) - try: - if len(x.shape) == 4: - # conv - down_weight = down_weight.view(batch_size, -1, 1, 1) - if x.shape[1] != down_weight.shape[1]: - raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}") - elif len(x.shape) == 2: - down_weight = down_weight.view(batch_size, -1) - if x.shape[1] != down_weight.shape[1]: - raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}") - else: - down_weight = down_weight.view(batch_size, 1, -1) - if x.shape[2] != down_weight.shape[2]: - raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}") - x = x * down_weight - x = self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs) - except Exception as e: - print(e) - raise ValueError(f"Down weight shape not understood: {down_weight.shape} {x.shape}") - - return x - - def up_forward(self, x, *args, **kwargs): - # do mid here - x = self.mid_forward(x, *args, **kwargs) - if not self.do_up: - return self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs) - # get the embed - self.embed = self.instant_lora_module_ref().img_embeds[self.index] - up_weight = self.embed[:, -self.out_dim:] - - batch_size = x.shape[0] - - # unconditional - if up_weight.shape[0] * 2 == batch_size: - up_weight = torch.cat([up_weight] * 2, dim=0) - - try: - if len(x.shape) == 4: - # conv - up_weight = up_weight.view(batch_size, -1, 1, 1) - elif len(x.shape) == 2: - up_weight = up_weight.view(batch_size, -1) - else: - up_weight = up_weight.view(batch_size, 1, -1) - x = self.lora_module_ref().lora_up.orig_forward(x, *args, **kwargs) - x = x * up_weight - except Exception as e: - print(e) - raise ValueError(f"Up weight shape not understood: {up_weight.shape} {x.shape}") - - return x - - def mid_forward(self, x, *args, **kwargs): - if not self.do_mid: - return self.lora_module_ref().lora_down.orig_forward(x, *args, **kwargs) - batch_size = x.shape[0] - # get the embed - self.embed = self.instant_lora_module_ref().img_embeds[self.index] - mid_weight = self.embed[:, self.down_dim:self.down_dim + self.mid_dim * self.mid_dim] - - # unconditional - if mid_weight.shape[0] * 2 == batch_size: - mid_weight = torch.cat([mid_weight] * 2, dim=0) - - weight_chunks = torch.chunk(mid_weight, batch_size, dim=0) + weight_chunks = torch.chunk(down_weight, batch_size, dim=0) x_chunks = torch.chunk(x, batch_size, dim=0) x_out = [] @@ -224,11 +149,43 @@ class InstantLoRAMidModule(torch.nn.Module): weight_chunk = weight_chunks[i] x_chunk = x_chunks[i] # reshape - if len(x_chunk.shape) == 4: - # conv - weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim, 1, 1) + weight_chunk = weight_chunk.view(self.down_shape) + # check if is conv or linear + if len(weight_chunk.shape) == 4: + org_module = self.lora_module_ref().orig_module_ref() + stride = org_module.stride + padding = org_module.padding + x_chunk = nn.functional.conv2d(x_chunk, weight_chunk, padding=padding, stride=stride) else: - weight_chunk = weight_chunk.view(self.mid_dim, self.mid_dim) + # run a simple linear layer with the down weight + x_chunk = x_chunk @ weight_chunk.T + x_out.append(x_chunk) + x = torch.cat(x_out, dim=0) + return x + + + def up_forward(self, x, *args, **kwargs): + self.embed = self.instant_lora_module_ref().img_embeds[self.index] + if x.dtype != self.embed.dtype: + x = x.to(self.embed.dtype) + up_size = math.prod(self.up_shape) + up_weight = self.embed[:, -up_size:] + + batch_size = x.shape[0] + + # unconditional + if up_weight.shape[0] * 2 == batch_size: + up_weight = torch.cat([up_weight] * 2, dim=0) + + weight_chunks = torch.chunk(up_weight, batch_size, dim=0) + x_chunks = torch.chunk(x, batch_size, dim=0) + + x_out = [] + for i in range(batch_size): + weight_chunk = weight_chunks[i] + x_chunk = x_chunks[i] + # reshape + weight_chunk = weight_chunk.view(self.up_shape) # check if is conv or linear if len(weight_chunk.shape) == 4: padding = 0 @@ -243,15 +200,17 @@ class InstantLoRAMidModule(torch.nn.Module): return x + + class InstantLoRAModule(torch.nn.Module): def __init__( self, vision_hidden_size: int, vision_tokens: int, head_dim: int, - num_heads: int, # number of heads in the resampler + num_heads: int, # number of heads in the resampler sd: 'StableDiffusion', - config: AdapterConfig + config=None ): super(InstantLoRAModule, self).__init__() # self.linear = torch.nn.Linear(2, 1) @@ -262,8 +221,6 @@ class InstantLoRAModule(torch.nn.Module): self.head_dim = head_dim self.num_heads = num_heads - self.config: AdapterConfig = config - # stores the projection vector. Grabbed by modules self.img_embeds: List[torch.Tensor] = None @@ -286,21 +243,11 @@ class InstantLoRAModule(torch.nn.Module): self.weight_mapping.append([lora_module.lora_name, [down_shape, up_shape]]) - # - # module_size = math.prod(down_shape) + math.prod(up_shape) - - # conv weight shape is (out_channels, in_channels, kernel_size, kernel_size) - # linear weight shape is (out_features, in_features) - - # just doing in dim and out dim - in_dim = down_shape[1] if self.config.ilora_down else 0 - mid_dim = down_shape[0] * down_shape[0] if self.config.ilora_mid else 0 - out_dim = up_shape[0] if self.config.ilora_up else 0 - module_size = in_dim + mid_dim + out_dim - + module_size = math.prod(down_shape) + math.prod(up_shape) output_size += module_size self.embed_lengths.append(module_size) + # add a new mid module that will take the original forward and add a vector to it # this will be used to add the vector to the original forward instant_module = InstantLoRAMidModule( @@ -314,11 +261,10 @@ class InstantLoRAModule(torch.nn.Module): self.ilora_modules.append(instant_module) # replace the LoRA forwards - lora_module.lora_down.orig_forward = lora_module.lora_down.forward lora_module.lora_down.forward = instant_module.down_forward - lora_module.lora_up.orig_forward = lora_module.lora_up.forward lora_module.lora_up.forward = instant_module.up_forward + self.output_size = output_size number_formatted_output_size = "{:,}".format(output_size) @@ -378,6 +324,7 @@ class InstantLoRAModule(torch.nn.Module): # print("No keymap found. Using default names") # return + def forward(self, img_embeds): # expand token rank if only rank 2 if len(img_embeds.shape) == 2: @@ -394,9 +341,10 @@ class InstantLoRAModule(torch.nn.Module): # get all the slices start = 0 for length in self.embed_lengths: - self.img_embeds.append(img_embeds[:, start:start + length]) + self.img_embeds.append(img_embeds[:, start:start+length]) start += length + def get_additional_save_metadata(self) -> Dict[str, Any]: # save the weight mapping return { @@ -406,7 +354,5 @@ class InstantLoRAModule(torch.nn.Module): "head_dim": self.head_dim, "vision_tokens": self.vision_tokens, "output_size": self.output_size, - "do_up": self.config.ilora_up, - "do_mid": self.config.ilora_mid, - "do_down": self.config.ilora_down, } + diff --git a/toolkit/models/llm_adapter.py b/toolkit/models/llm_adapter.py index d3fd612d..9c098f90 100644 --- a/toolkit/models/llm_adapter.py +++ b/toolkit/models/llm_adapter.py @@ -65,8 +65,8 @@ class LLMAdapter(torch.nn.Module): if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token - # self.system_prompt = "" - self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. " + self.system_prompt = "" + # self.system_prompt = "You are an assistant designed to generate superior images with the superior degree of image-text alignment based on textual prompts or user prompts. " # determine length of system prompt sys_prompt_tokenized = tokenizer( diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index c8c01aa2..66c8c19c 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -1401,8 +1401,7 @@ class StableDiffusion: conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) self.adapter(conditional_clip_embeds) - if self.adapter is not None and isinstance(self.adapter, CustomAdapter) \ - and gen_config.adapter_image_path is not None: + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): # handle condition the prompts gen_config.prompt = self.adapter.condition_prompt( gen_config.prompt, @@ -1456,7 +1455,7 @@ class StableDiffusion: conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds, is_unconditional=False) unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds, is_unconditional=True) - if self.adapter is not None and isinstance(self.adapter, CustomAdapter) and validation_image is not None: + if self.adapter is not None and isinstance(self.adapter, CustomAdapter): conditional_embeds = self.adapter.condition_encoded_embeds( tensors_0_1=validation_image, prompt_embeds=conditional_embeds, diff --git a/toolkit/util/blended_blur_noise.py b/toolkit/util/blended_blur_noise.py new file mode 100644 index 00000000..86f46ddb --- /dev/null +++ b/toolkit/util/blended_blur_noise.py @@ -0,0 +1,84 @@ +import torch + +cached_multipier = None + +def get_multiplier(timesteps, num_timesteps=1000): + global cached_multipier + if cached_multipier is None: + # creates a bell curve + x = torch.arange(num_timesteps, dtype=torch.float32) + y = torch.exp(-2 * ((x - num_timesteps / 2) / num_timesteps) ** 2) + + # Shift minimum to 0 + y_shifted = y - y.min() + + # Scale to make mean 1 + cached_multipier = y_shifted * (num_timesteps / y_shifted.sum()) + + scale_list = [] + # get the idx multiplier for each timestep + for i in range(timesteps.shape[0]): + idx = min(int(timesteps[i].item()) - 1, 0) + scale_list.append(cached_multipier[idx:idx + 1]) + + scales = torch.cat(scale_list, dim=0) + + batch_multiplier = scales.view(-1, 1, 1, 1) + + return batch_multiplier + + +def get_blended_blur_noise(latents, noise, timestep): + latent_chunks = torch.chunk(latents, latents.shape[0], dim=0) + + # timestep is 1000 to 0 + # timestep = timestep.to(latents.device, dtype=latents.dtype) + + # scale it so timestep 1000 is 0 and 0 is 2 + # blur_strength = value_map(timestep, 1000, 0, 0, 1.0) + # blur_strength = timestep / 500.0 + # blur_strength = blur_strength.view(-1, 1, 1, 1) + + # scale to 2.0 max + # blur_strength = get_multiplier(timestep).to( + # latents.device, dtype=latents.dtype + # ) * 2.0 + + # blur_strength = 2.0 + + blurred_latent_chunks = [] + for i in range(len(latent_chunks)): + latent_chunk = latent_chunks[i] + # get two random scalers 0.1 to 0.9 + # scaler1 = random.uniform(0.2, 0.8) + scaler1 = 0.25 + scaler2 = scaler1 + + # shrink latents by 1/4 and bring them back for blurring using interpolation + blur_latents = torch.nn.functional.interpolate( + latent_chunk, + size=(int(latents.shape[2] * scaler1), int(latents.shape[3] * scaler2)), + mode='bilinear', + align_corners=False + ) + blur_latents = torch.nn.functional.interpolate( + blur_latents, + size=(latents.shape[2], latents.shape[3]), + mode='bilinear', + align_corners=False + ) + # only the difference of the blur from ground truth + blur_latents = blur_latents - latent_chunk + blurred_latent_chunks.append(blur_latents) + + blur_latents = torch.cat(blurred_latent_chunks, dim=0) + + + # make random strength along batch 0 to 1 + blur_strength = torch.rand((latents.shape[0], 1, 1, 1), device=latents.device, dtype=latents.dtype) * 2 + + blur_latents = blur_latents * blur_strength + + noise = noise + blur_latents + return noise + \ No newline at end of file diff --git a/toolkit/util/shuffle.py b/toolkit/util/shuffle.py index 7d735f9d..940dee17 100644 --- a/toolkit/util/shuffle.py +++ b/toolkit/util/shuffle.py @@ -41,6 +41,9 @@ def shuffle_tensor_along_axis(tensor, axis=0, seed=None): # Apply the shuffle shuffled_tensor = tensor[slices] + + except Exception as e: + raise RuntimeError(f"Error during shuffling: {e}") finally: # Restore original random states