diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 3d8fcde3..a4b1c2d2 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -275,16 +275,18 @@ class SDTrainer(BaseSDTrainProcess): if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask: assert not self.train_config.train_turbo - # we need to make the noise prediction be a masked blending of noise and prior_pred - stretched_mask_multiplier = value_map( - mask_multiplier, - batch.file_items[0].dataset_config.mask_min_value, - 1.0, - 0.0, - 1.0 - ) + with torch.no_grad(): + # we need to make the noise prediction be a masked blending of noise and prior_pred + stretched_mask_multiplier = value_map( + mask_multiplier, + batch.file_items[0].dataset_config.mask_min_value, + 1.0, + 0.0, + 1.0 + ) + + prior_mask_multiplier = 1.0 - stretched_mask_multiplier - prior_mask_multiplier = 1.0 - stretched_mask_multiplier # target_mask_multiplier = mask_multiplier # mask_multiplier = 1.0 diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 8e91c825..9adf05d1 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -940,6 +940,11 @@ class BaseSDTrainProcess(BaseTrainProcess): batch.mask_tensor = double_up_tensor(batch.mask_tensor) batch.control_tensor = double_up_tensor(batch.control_tensor) + noisy_latent_multiplier = self.train_config.noisy_latent_multiplier + + if noisy_latent_multiplier != 1.0: + noisy_latents = noisy_latents * noisy_latent_multiplier + # remove grads for these noisy_latents.requires_grad = False noisy_latents = noisy_latents.detach() diff --git a/jobs/process/GenerateProcess.py b/jobs/process/GenerateProcess.py index f5fa22b0..95bb7427 100644 --- a/jobs/process/GenerateProcess.py +++ b/jobs/process/GenerateProcess.py @@ -1,7 +1,7 @@ import gc import os from collections import OrderedDict -from typing import ForwardRef, List +from typing import ForwardRef, List, Optional, Union import torch from safetensors.torch import save_file, load_file @@ -22,6 +22,7 @@ class GenerateConfig: self.sampler = kwargs.get('sampler', 'ddpm') self.width = kwargs.get('width', 512) self.height = kwargs.get('height', 512) + self.size_list: Union[List[int], None] = kwargs.get('size_list', None) self.neg = kwargs.get('neg', '') self.seed = kwargs.get('seed', -1) self.guidance_scale = kwargs.get('guidance_scale', 7) @@ -30,6 +31,7 @@ class GenerateConfig: self.neg_2 = kwargs.get('neg_2', None) self.prompts = kwargs.get('prompts', None) self.guidance_rescale = kwargs.get('guidance_rescale', 0.0) + self.compile = kwargs.get('compile', False) self.ext = kwargs.get('ext', 'png') self.prompt_file = kwargs.get('prompt_file', False) self.prompts_in_file = self.prompts @@ -93,17 +95,26 @@ class GenerateProcess(BaseProcess): self.sd.load_model() print("Compiling model...") - self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True) + # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True) + if self.generate_config.compile: + self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead") print(f"Generating {len(self.generate_config.prompts)} images") # build prompt image configs prompt_image_configs = [] for prompt in self.generate_config.prompts: + width = self.generate_config.width + height = self.generate_config.height + + if self.generate_config.size_list is not None: + # randomly select a size + width, 
height = random.choice(self.generate_config.size_list) + prompt_image_configs.append(GenerateImageConfig( prompt=prompt, prompt_2=self.generate_config.prompt_2, - width=self.generate_config.width, - height=self.generate_config.height, + width=width, + height=height, num_inference_steps=self.generate_config.sample_steps, guidance_scale=self.generate_config.guidance_scale, negative_prompt=self.generate_config.neg, diff --git a/testing/test_bucket_dataloader.py b/testing/test_bucket_dataloader.py index a223307a..eeac0d17 100644 --- a/testing/test_bucket_dataloader.py +++ b/testing/test_bucket_dataloader.py @@ -21,12 +21,14 @@ from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets, trigger_dataloader_setup_epoch from toolkit.config_modules import DatasetConfig import argparse +from tqdm import tqdm parser = argparse.ArgumentParser() parser.add_argument('dataset_folder', type=str, default='input') parser.add_argument('--epochs', type=int, default=1) + args = parser.parse_args() dataset_folder = args.dataset_folder @@ -40,27 +42,27 @@ batch_size = 1 dataset_config = DatasetConfig( dataset_path=dataset_folder, resolution=resolution, - caption_ext='json', + # caption_ext='json', default_caption='default', - clip_image_path='/mnt/Datasets/face_pairs2/control_clean', + # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/', buckets=True, bucket_tolerance=bucket_tolerance, - poi='person', - augmentations=[ - { - 'method': 'RandomBrightnessContrast', - 'brightness_limit': (-0.3, 0.3), - 'contrast_limit': (-0.3, 0.3), - 'brightness_by_max': False, - 'p': 1.0 - }, - { - 'method': 'HueSaturationValue', - 'hue_shift_limit': (-0, 0), - 'sat_shift_limit': (-40, 40), - 'val_shift_limit': (-40, 40), - 'p': 1.0 - }, + # poi='person', + # augmentations=[ + # { + # 'method': 'RandomBrightnessContrast', + # 'brightness_limit': (-0.3, 0.3), + # 'contrast_limit': (-0.3, 0.3), + # 'brightness_by_max': False, + # 'p': 1.0 + # }, + # { + # 'method': 'HueSaturationValue', + # 'hue_shift_limit': (-0, 0), + # 'sat_shift_limit': (-40, 40), + # 'val_shift_limit': (-40, 40), + # 'p': 1.0 + # }, # { # 'method': 'RGBShift', # 'r_shift_limit': (-20, 20), @@ -68,7 +70,7 @@ dataset_config = DatasetConfig( # 'b_shift_limit': (-20, 20), # 'p': 1.0 # }, - ] + # ] ) @@ -79,7 +81,7 @@ dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_si # run through an epoch ang check sizes dataloader_iterator = iter(dataloader) for epoch in range(args.epochs): - for batch in dataloader: + for batch in tqdm(dataloader): batch: 'DataLoaderBatchDTO' img_batch = batch.tensor @@ -98,7 +100,7 @@ for epoch in range(args.epochs): show_img(img) - time.sleep(1.0) + # time.sleep(0.1) # if not last epoch if epoch < args.epochs - 1: trigger_dataloader_setup_epoch(dataloader) diff --git a/toolkit/clip_vision_adapter.py b/toolkit/clip_vision_adapter.py index 83580fa9..636de814 100644 --- a/toolkit/clip_vision_adapter.py +++ b/toolkit/clip_vision_adapter.py @@ -41,20 +41,37 @@ class Embedder(nn.Module): self.layer_norm = nn.LayerNorm(input_dim) self.fc1 = nn.Linear(input_dim, mid_dim) self.gelu = nn.GELU() - self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens) + # self.fc2 = nn.Linear(mid_dim, mid_dim) + self.fc2 = nn.Linear(mid_dim, mid_dim) - self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim)) + self.fc2.weight.data.zero_() + + self.layer_norm2 = nn.LayerNorm(mid_dim) + self.fc3 = nn.Linear(mid_dim, mid_dim) + self.gelu2 = nn.GELU() + self.fc4 = 
nn.Linear(mid_dim, output_dim * num_output_tokens) + + # set the weights to 0 + self.fc3.weight.data.zero_() + self.fc4.weight.data.zero_() + + + # self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim)) + # self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim)) def forward(self, x): + if len(x.shape) == 2: + x = x.unsqueeze(1) x = self.layer_norm(x) x = self.fc1(x) x = self.gelu(x) x = self.fc2(x) - x = x.view(-1, self.num_output_tokens, self.output_dim) + x = self.layer_norm2(x) + x = self.fc3(x) + x = self.gelu2(x) + x = self.fc4(x) - # repeat the static tokens for each batch - static_tokens = torch.stack([self.static_tokens] * x.shape[0]) - x = static_tokens + x + x = x.view(-1, self.num_output_tokens, self.output_dim) return x @@ -89,6 +106,7 @@ class ClipVisionAdapter(torch.nn.Module): print(f"Adding {placeholder_tokens} tokens to tokenizer") print(f"Adding {self.config.num_tokens} tokens to tokenizer") + for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list): num_added_tokens = tokenizer.add_tokens(placeholder_tokens) if num_added_tokens != self.config.num_tokens: diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 54bfcb7f..4944c89a 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -246,6 +246,7 @@ class TrainConfig: self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None) self.noise_multiplier = kwargs.get('noise_multiplier', 1.0) self.img_multiplier = kwargs.get('img_multiplier', 1.0) + self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0) self.latent_multiplier = kwargs.get('latent_multiplier', 1.0) self.negative_prompt = kwargs.get('negative_prompt', None) self.max_negative_prompts = kwargs.get('max_negative_prompts', 1) diff --git a/toolkit/embedding.py b/toolkit/embedding.py index 31ac4ce2..94ba3f2f 100644 --- a/toolkit/embedding.py +++ b/toolkit/embedding.py @@ -86,18 +86,19 @@ class Embedding: self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list] def restore_embeddings(self): - # Let's make sure we don't update any embedding weights besides the newly added token - for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list, - self.tokenizer_list, - self.orig_embeds_params, - self.placeholder_token_ids): - index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) - index_no_updates[ - min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False - with torch.no_grad(): + with torch.no_grad(): + # Let's make sure we don't update any embedding weights besides the newly added token + for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list, + self.tokenizer_list, + self.orig_embeds_params, + self.placeholder_token_ids): + index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool) + index_no_updates[ min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False text_encoder.get_input_embeddings().weight[ index_no_updates ] = orig_embeds[index_no_updates] + weight = text_encoder.get_input_embeddings().weight + pass def get_trainable_params(self): params = [] diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py index 1fdbcb20..830bba8b 100644 --- a/toolkit/ip_adapter.py +++ b/toolkit/ip_adapter.py @@ -387,7 +387,7 @@ class IPAdapter(torch.nn.Module): cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim'] image_proj_model = 
MLPProjModelClipFace( cross_attention_dim=cross_attn_dim, - id_embeddings_dim=1024, + id_embeddings_dim=self.image_encoder.config.projection_dim, num_tokens=self.config.num_tokens, # usually 4 ) elif adapter_config.type == 'ip+': @@ -486,7 +486,21 @@ class IPAdapter(torch.nn.Module): attn_processor_names = [] + blocks = [] + transformer_blocks = [] for name in attn_processor_keys: + name_split = name.split(".") + block_name = f"{name_split[0]}.{name_split[1]}" + transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1 + if transformer_idx >= 0: + transformer_name = ".".join(name_split[:2]) + transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2]) + if transformer_name not in transformer_blocks: + transformer_blocks.append(transformer_name) + + + if block_name not in blocks: + blocks.append(block_name) cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \ sd.unet.config['cross_attention_dim'] if name.startswith("mid_block"): diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py index 745a3687..a41e8150 100644 --- a/toolkit/models/ilora.py +++ b/toolkit/models/ilora.py @@ -15,6 +15,30 @@ if TYPE_CHECKING: from toolkit.stable_diffusion_model import StableDiffusion +class ILoRAProjModule(torch.nn.Module): + def __init__(self, num_modules=1, dim=4, embeddings_dim=512): + super().__init__() + + self.num_modules = num_modules + self.num_dim = dim + self.norm = torch.nn.LayerNorm(embeddings_dim) + + self.proj = torch.nn.Sequential( + torch.nn.Linear(embeddings_dim, embeddings_dim * 2), + torch.nn.GELU(), + torch.nn.Linear(embeddings_dim * 2, num_modules * dim), + ) + # Initialize the last linear layer weights near zero + torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01) + torch.nn.init.zeros_(self.proj[2].bias) + + def forward(self, x): + x = self.norm(x) + x = self.proj(x) + x = x.reshape(-1, self.num_modules, self.num_dim) + return x + + class InstantLoRAMidModule(torch.nn.Module): def __init__( self, @@ -54,7 +78,7 @@ class InstantLoRAMidModule(torch.nn.Module): raise e # apply tanh to limit values to -1 to 1 # scaler = torch.tanh(scaler) - return x * (scaler + 1.0) + return x * scaler class InstantLoRAModule(torch.nn.Module): @@ -92,20 +116,25 @@ class InstantLoRAModule(torch.nn.Module): # num_blocks=1, # ) # heads = 20 - heads = 12 - dim = 1280 - output_dim = self.dim - self.resampler = Resampler( - dim=dim, - depth=4, - dim_head=64, - heads=heads, - num_queries=len(lora_modules), - embedding_dim=self.vision_hidden_size, - max_seq_len=self.vision_tokens, - output_dim=output_dim, - ff_mult=4 - ) + # heads = 12 + # dim = 1280 + # output_dim = self.dim + self.proj_module = ILoRAProjModule( + num_modules=len(lora_modules), + dim=self.dim, + embeddings_dim=self.vision_hidden_size, + ) + # self.resampler = Resampler( + # dim=dim, + # depth=4, + # dim_head=64, + # heads=heads, + # num_queries=len(lora_modules), + # embedding_dim=self.vision_hidden_size, + # max_seq_len=self.vision_tokens, + # output_dim=output_dim, + # ff_mult=4 + # ) for idx, lora_module in enumerate(lora_modules): # add a new mid module that will take the original forward and add a vector to it @@ -128,6 +157,6 @@ class InstantLoRAModule(torch.nn.Module): # expand token rank if only rank 2 if len(img_embeds.shape) == 2: img_embeds = img_embeds.unsqueeze(1) - img_embeds = self.resampler(img_embeds) + img_embeds = self.proj_module(img_embeds) self.img_embeds = img_embeds 
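Quick shape check for the new ILoRAProjModule introduced in the ilora.py hunk above. The sizes below (16 LoRA mid modules, projection dim 4, vision hidden size 1024) are illustrative placeholders, not values taken from a real config, and the snippet assumes the patched toolkit.models.ilora module imports cleanly:

import torch
from toolkit.models.ilora import ILoRAProjModule

# hypothetical sizes, for illustration only
proj = ILoRAProjModule(num_modules=16, dim=4, embeddings_dim=1024)
img_embeds = torch.randn(2, 1024)   # (batch, embeddings_dim), e.g. pooled vision embeddings
scalers = proj(img_embeds)          # -> (2, 16, 4): one scaler vector per mid module

Note that InstantLoRAMidModule now returns x * scaler instead of x * (scaler + 1.0); combined with the near-zero init of the final projection layer, the conditioned path starts out close to a zero contribution rather than close to a pass-through of x.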
diff --git a/toolkit/train_tools.py b/toolkit/train_tools.py index e80c0649..1bae37b5 100644 --- a/toolkit/train_tools.py +++ b/toolkit/train_tools.py @@ -390,7 +390,7 @@ def sample_images( # https://www.crosslabs.org//blog/diffusion-with-offset-noise def apply_noise_offset(noise, noise_offset): - if noise_offset is None or noise_offset < 0.0000001: + if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001): return noise noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device) return noise
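Note on the apply_noise_offset change: the previous guard (noise_offset < 0.0000001) short-circuited for every negative offset, so negative offsets were silently dropped; the new symmetric bound only skips near-zero values. A minimal standalone sketch of the patched behavior, written with an abs() check equivalent to the bound in the diff (the function name here is just for illustration):

import torch

def apply_noise_offset_sketch(noise: torch.Tensor, noise_offset) -> torch.Tensor:
    # Skip only near-zero offsets; negative offsets are now applied.
    if noise_offset is None or abs(noise_offset) < 1e-6:
        return noise
    # Offset noise: shift each (sample, channel) plane by one random constant.
    return noise + noise_offset * torch.randn(
        (noise.shape[0], noise.shape[1], 1, 1), device=noise.device
    )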