diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 33f162e7..3c85de62 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -270,6 +270,10 @@ class SDTrainer(BaseSDTrainProcess):
             noise_pred = noise_pred * self.train_config.pred_scaler
 
         target = None
+
+        if self.train_config.target_noise_multiplier != 1.0:
+            noise = noise * self.train_config.target_noise_multiplier
+
         if self.train_config.correct_pred_norm or (self.train_config.inverted_mask_prior and prior_pred is not None and has_mask):
             if self.train_config.correct_pred_norm and not is_reg:
                 with torch.no_grad():
diff --git a/testing/merge_in_text_encoder_adapter.py b/testing/merge_in_text_encoder_adapter.py
index e4a5cff9..08d5c02e 100644
--- a/testing/merge_in_text_encoder_adapter.py
+++ b/testing/merge_in_text_encoder_adapter.py
@@ -11,10 +11,10 @@ import json
 # te_path = "google/flan-t5-xl"
 # te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
 # output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
-model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS"
-te_path = "google/flan-t5-base"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000227500.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5base_raw"
+model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"
+te_path = "google/flan-t5-large"
+te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5l_000034000.safetensors"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw"
 
 
 print("Loading te adapter")
diff --git a/testing/shrink_pixart_sm.py b/testing/shrink_pixart_sm.py
index c1e2f608..8cea07bf 100644
--- a/testing/shrink_pixart_sm.py
+++ b/testing/shrink_pixart_sm.py
@@ -2,62 +2,83 @@ import torch
 from safetensors.torch import load_file, save_file
 from collections import OrderedDict
 
-model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.orig.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.safetensors"
-
-state_dict = load_file(model_path)
-
 meta = OrderedDict()
-meta["format"] = "pt"
+meta['format'] = "pt"
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+    weight = weight.to(device, torch.float32)
+    original_shape = weight.shape
+    flattened = weight.view(-1, original_shape[-1])
+
+    if flattened.shape[1] <= target_size:
+        return weight
+
+    U, S, V = torch.svd(flattened)
+    reduced = torch.mm(U[:, :target_size], torch.diag(S[:target_size]))
+
+    if reduced.shape[1] < target_size:
+        padding = torch.zeros(reduced.shape[0], target_size - reduced.shape[1], device=device)
+        reduced = torch.cat((reduced, padding), dim=1)
+
+    return reduced.view(original_shape[:-1] + (target_size,))
+
+
+def reduce_bias(bias, target_size):
+    bias = bias.to(device, torch.float32)
+    original_size = bias.shape[0]
+
+    if original_size <= target_size:
+        return torch.nn.functional.pad(bias, (0, target_size - original_size))
+    else:
+        return bias.view(-1, original_size // target_size).mean(dim=1)[:target_size]
+
+
+# Load your original state dict
+state_dict = load_file(
+    "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
 new_state_dict = {}
-# Move non-blocks over
+source_hidden_size = 1152
+target_hidden_size = 1024
+
 for key, value in state_dict.items():
-    if not key.startswith("transformer_blocks."):
-        new_state_dict[key] = value
+    value = value.to(device, torch.float32)
+    if 'weight' in key or 'scale_shift_table' in key:
+        if value.shape[0] == source_hidden_size:
+            value = value[:target_hidden_size]
+        elif value.shape[0] == source_hidden_size * 4:
+            value = value[:target_hidden_size * 4]
+        elif value.shape[0] == source_hidden_size * 6:
+            value = value[:target_hidden_size * 6]
 
-block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
-               'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
-               'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
-               'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
-               'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
-               'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
-               'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
-               'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
-               'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
-               'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
-               'transformer_blocks.{idx}.scale_shift_table']
+        if len(value.shape) > 1 and value.shape[
+            1] == source_hidden_size and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+            value = value[:, :target_hidden_size]
+        elif len(value.shape) > 1 and value.shape[1] == source_hidden_size * 4:
+            value = value[:, :target_hidden_size * 4]
 
-# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27
+    elif 'bias' in key:
+        if value.shape[0] == source_hidden_size:
+            value = value[:target_hidden_size]
+        elif value.shape[0] == source_hidden_size * 4:
+            value = value[:target_hidden_size * 4]
+        elif value.shape[0] == source_hidden_size * 6:
+            value = value[:target_hidden_size * 6]
 
-current_idx = 0
-for i in range(28):
-    if i not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
-        # todo merge in with previous block
-        for name in block_names:
-            continue
-            # try:
-            #     new_state_dict_key = name.format(idx=current_idx - 1)
-            #     old_state_dict_key = name.format(idx=i)
-            #     new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5)
-            # except KeyError:
-            #     raise KeyError(f"KeyError: {name.format(idx=current_idx)}")
-    else:
-        for name in block_names:
-            new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)]
-        current_idx += 1
+    new_state_dict[key] = value
 
-
-# make sure they are all fp16 and on cpu
 for key, value in new_state_dict.items():
-    new_state_dict[key] = value.to(torch.float16).cpu()
+    new_state_dict[key] = value.cpu().to(torch.float16)
 
-# save the new state dict
-save_file(new_state_dict, output_path, metadata=meta)
+# Save the new state dict
+save_file(new_state_dict,
+          "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+          metadata=meta)
 
-new_param_count = sum([v.numel() for v in new_state_dict.values()])
-old_param_count = sum([v.numel() for v in state_dict.values()])
-
-print(f"Old param count: {old_param_count:,}")
-print(f"New param count: {new_param_count:,}")
\ No newline at end of file
+print("Done!")
diff --git a/testing/shrink_pixart_sm2.py b/testing/shrink_pixart_sm2.py
new file mode 100644
index 00000000..dd3304df
--- /dev/null
+++ b/testing/shrink_pixart_sm2.py
@@ -0,0 +1,110 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+meta = OrderedDict()
+meta['format'] = "pt"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+    weight = weight.to(device, torch.float32)
+    original_shape = weight.shape
+
+    if len(original_shape) == 1:
+        # For 1D tensors, simply truncate
+        return weight[:target_size]
+
+    if original_shape[0] <= target_size:
+        return weight
+
+    # Reshape the tensor to 2D
+    flattened = weight.reshape(original_shape[0], -1)
+
+    # Perform SVD
+    U, S, V = torch.svd(flattened)
+
+    # Reduce the dimensions
+    reduced = torch.mm(U[:target_size, :], torch.diag(S)).mm(V.t())
+
+    # Reshape back to the original shape with reduced first dimension
+    new_shape = (target_size,) + original_shape[1:]
+    return reduced.reshape(new_shape)
+
+
+def reduce_bias(bias, target_size):
+    bias = bias.to(device, torch.float32)
+    return bias[:target_size]
+
+
+# Load your original state dict
+state_dict = load_file(
+    "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
+new_state_dict = {}
+
+for key, value in state_dict.items():
+    value = value.to(device, torch.float32)
+
+    if 'weight' in key or 'scale_shift_table' in key:
+        if value.shape[0] == 1152:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3])  # reshape to (1152, -1)
+                # reshape to (1152, -1)
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 512)
+                value = value.view(output_shape)
+            else:
+                # value = reduce_weight(value.t(), 576).t().contiguous()
+                value = reduce_weight(value, 512)
+                pass
+        elif value.shape[0] == 4608:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3])
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 2048)
+                value = value.view(output_shape)
+            else:
+                value = reduce_weight(value, 2048)
+        elif value.shape[0] == 6912:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3])
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 3072)
+                value = value.view(output_shape)
+            else:
+                value = reduce_weight(value, 3072)
+
+        if len(value.shape) > 1 and value.shape[
+            1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+            value = reduce_weight(value.t(), 512).t().contiguous()  # Transpose before and after reduction
+            pass
+        elif len(value.shape) > 1 and value.shape[1] == 4608:
+            value = reduce_weight(value.t(), 2048).t().contiguous()  # Transpose before and after reduction
+            pass
+
+    elif 'bias' in key:
+        if value.shape[0] == 1152:
+            value = reduce_bias(value, 512)
+        elif value.shape[0] == 4608:
+            value = reduce_bias(value, 2048)
+        elif value.shape[0] == 6912:
+            value = reduce_bias(value, 3072)
+
+    new_state_dict[key] = value
+
+# Move all to CPU and convert to float16
+for key, value in new_state_dict.items():
+    new_state_dict[key] = value.cpu().to(torch.float16)
+
+# Save the new state dict
+save_file(new_state_dict,
+          "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+          metadata=meta)
+
+print("Done!")
\ No newline at end of file
diff --git a/testing/shrink_pixart_sm3.py b/testing/shrink_pixart_sm3.py
new file mode 100644
index 00000000..b8756aec
--- /dev/null
+++ b/testing/shrink_pixart_sm3.py
@@ -0,0 +1,100 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+meta = OrderedDict()
+meta['format'] = "pt"
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def reduce_weight(weight, target_size):
+    weight = weight.to(device, torch.float32)
+    # resize so target_size is the first dimension
+    tmp_weight = weight.view(1, 1, weight.shape[0], weight.shape[1])
+
+    # use interpolate to resize the tensor
+    new_weight = torch.nn.functional.interpolate(tmp_weight, size=(target_size, weight.shape[1]), mode='bicubic', align_corners=True)
+
+    # reshape back to original shape
+    return new_weight.view(target_size, weight.shape[1])
+
+
+def reduce_bias(bias, target_size):
+    bias = bias.view(1, 1, bias.shape[0], 1)
+
+    new_bias = torch.nn.functional.interpolate(bias, size=(target_size, 1), mode='bicubic', align_corners=True)
+
+    return new_bias.view(target_size)
+
+
+# Load your original state dict
+state_dict = load_file(
+    "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.orig.safetensors")
+
+# Create a new state dict for the reduced model
+new_state_dict = {}
+
+for key, value in state_dict.items():
+    value = value.to(device, torch.float32)
+
+    if 'weight' in key or 'scale_shift_table' in key:
+        if value.shape[0] == 1152:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (512, orig_shape[1], orig_shape[2], orig_shape[3])  # reshape to (1152, -1)
+                # reshape to (1152, -1)
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 512)
+                value = value.view(output_shape)
+            else:
+                # value = reduce_weight(value.t(), 576).t().contiguous()
+                value = reduce_weight(value, 512)
+                pass
+        elif value.shape[0] == 4608:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (2048, orig_shape[1], orig_shape[2], orig_shape[3])
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 2048)
+                value = value.view(output_shape)
+            else:
+                value = reduce_weight(value, 2048)
+        elif value.shape[0] == 6912:
+            if len(value.shape) == 4:
+                orig_shape = value.shape
+                output_shape = (3072, orig_shape[1], orig_shape[2], orig_shape[3])
+                value = value.view(value.shape[0], -1)
+                value = reduce_weight(value, 3072)
+                value = value.view(output_shape)
+            else:
+                value = reduce_weight(value, 3072)
+
+        if len(value.shape) > 1 and value.shape[
+            1] == 1152 and 'attn2.to_k.weight' not in key and 'attn2.to_v.weight' not in key:
+            value = reduce_weight(value.t(), 512).t().contiguous()  # Transpose before and after reduction
+            pass
+        elif len(value.shape) > 1 and value.shape[1] == 4608:
+            value = reduce_weight(value.t(), 2048).t().contiguous()  # Transpose before and after reduction
+            pass
+
+    elif 'bias' in key:
+        if value.shape[0] == 1152:
+            value = reduce_bias(value, 512)
+        elif value.shape[0] == 4608:
+            value = reduce_bias(value, 2048)
+        elif value.shape[0] == 6912:
+            value = reduce_bias(value, 3072)
+
+    new_state_dict[key] = value
+
+# Move all to CPU and convert to float16
+for key, value in new_state_dict.items():
+    new_state_dict[key] = value.cpu().to(torch.float16)
+
+# Save the new state dict
+save_file(new_state_dict,
+          "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw/transformer/diffusion_pytorch_model.safetensors",
+          metadata=meta)
+
+print("Done!")
\ No newline at end of file
diff --git a/toolkit/clip_vision_adapter.py b/toolkit/clip_vision_adapter.py
index 636de814..4ccc920c 100644
--- a/toolkit/clip_vision_adapter.py
+++ b/toolkit/clip_vision_adapter.py
@@ -166,7 +166,7 @@ class ClipVisionAdapter(torch.nn.Module):
         if hasattr(self.image_encoder.config, 'hidden_sizes'):
             embedding_dim = self.image_encoder.config.hidden_sizes[-1]
         else:
-            embedding_dim = self.image_encoder.config.hidden_size
+            embedding_dim = self.image_encoder.config.target_hidden_size
 
         if self.config.clip_layer == 'image_embeds':
             in_tokens = 1
@@ -308,15 +308,15 @@ class ClipVisionAdapter(torch.nn.Module):
             # add it to the text encoder
             self.set_vec(image_prompt_embeds[0], text_encoder_idx=0)
         elif len(self.text_encoder_list) == 2:
-            if self.text_encoder_list[0].config.hidden_size + self.text_encoder_list[1].config.hidden_size != \
+            if self.text_encoder_list[0].config.target_hidden_size + self.text_encoder_list[1].config.target_hidden_size != \
                     image_prompt_embeds.shape[2]:
                 raise ValueError("Something went wrong. The embeddings do not match the text encoder sizes")
             # sdxl variants
             # image_prompt_embeds = 2048
             # te1 = 768
             # te2 = 1280
-            te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.hidden_size]
-            te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.hidden_size:]
+            te1_embeds = image_prompt_embeds[:, :, :self.text_encoder_list[0].config.target_hidden_size]
+            te2_embeds = image_prompt_embeds[:, :, self.text_encoder_list[0].config.target_hidden_size:]
             self.set_vec(te1_embeds[0], text_encoder_idx=0)
             self.set_vec(te2_embeds[0], text_encoder_idx=1)
         else:
diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py
index 4b744adf..5edc3fb7 100644
--- a/toolkit/config_modules.py
+++ b/toolkit/config_modules.py
@@ -251,6 +251,7 @@ class TrainConfig:
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
         self.adapter_assist_type: Optional[str] = kwargs.get('adapter_assist_type', 't2i')  # t2i, control_net
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
+        self.target_noise_multiplier = kwargs.get('target_noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
         self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
         self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py
index 9104aac8..56909655 100644
--- a/toolkit/ip_adapter.py
+++ b/toolkit/ip_adapter.py
@@ -394,7 +394,7 @@ class IPAdapter(torch.nn.Module):
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith(
+            embedding_dim = self.image_encoder.config.target_hidden_size if not self.config.image_encoder_arch.startswith(
                 'convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]
 
@@ -436,7 +436,7 @@ class IPAdapter(torch.nn.Module):
             if hasattr(self.image_encoder.config, 'hidden_sizes'):
                 embedding_dim = self.image_encoder.config.hidden_sizes[-1]
             else:
-                embedding_dim = self.image_encoder.config.hidden_size
+                embedding_dim = self.image_encoder.config.target_hidden_size
 
             image_encoder_state_dict = self.image_encoder.state_dict()
             # max_seq_len = CLIP tokens + CLS token
diff --git a/toolkit/models/te_adapter.py b/toolkit/models/te_adapter.py
index 6da63a8d..bc182b70 100644
--- a/toolkit/models/te_adapter.py
+++ b/toolkit/models/te_adapter.py
@@ -246,7 +246,7 @@ class TEAdapter(torch.nn.Module):
         if self.adapter_ref().config.text_encoder_arch == "t5":
             self.token_size = self.te_ref().config.d_model
         else:
-            self.token_size = self.te_ref().config.hidden_size
+            self.token_size = self.te_ref().config.target_hidden_size
 
         # add text projection if is sdxl
         self.text_projection = None
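
Note: the new `target_noise_multiplier` option (wired through `TrainConfig` and `SDTrainer` above) scales the noise tensor that, for the standard diffusion objective, becomes the regression target, and it only kicks in when the value differs from 1.0. Below is a minimal standalone sketch of that behavior; the `SimpleNamespace` stand-in for the real `TrainConfig` and the tensor shapes are illustrative assumptions, not part of the diff.

```python
import torch
from types import SimpleNamespace

# Hypothetical stand-in for TrainConfig; only the field added in this diff matters here.
train_config = SimpleNamespace(target_noise_multiplier=0.9)

noise = torch.randn(1, 4, 64, 64)       # noise that will become the regression target
noise_pred = torch.randn(1, 4, 64, 64)  # placeholder for the model's prediction

# Same guard as the SDTrainer change: scale the target noise only when the multiplier is not 1.0
if train_config.target_noise_multiplier != 1.0:
    noise = noise * train_config.target_noise_multiplier

loss = torch.nn.functional.mse_loss(noise_pred, noise)
print(loss.item())
```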
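
For reference, `shrink_pixart_sm3.py` shrinks each weight by viewing it as a 1x1xHxW "image" and resizing the first dimension with bicubic interpolation, in contrast to the plain truncation in `shrink_pixart_sm.py` and the SVD projection in `shrink_pixart_sm2.py`. A self-contained usage sketch of that reduction on a random 2-D tensor, using the PixArt hidden sizes that appear in the scripts (1152 to 512); the example tensor itself is made up:

```python
import torch
import torch.nn.functional as F


def reduce_weight(weight: torch.Tensor, target_size: int) -> torch.Tensor:
    # View the 2-D weight as a 1x1xHxW "image" so interpolate can resize its rows
    tmp = weight.view(1, 1, weight.shape[0], weight.shape[1])
    resized = F.interpolate(tmp, size=(target_size, weight.shape[1]),
                            mode='bicubic', align_corners=True)
    return resized.view(target_size, weight.shape[1])


w = torch.randn(1152, 1152)      # e.g. an 1152-wide projection weight
w_small = reduce_weight(w, 512)  # rows reduced 1152 -> 512, columns untouched
print(w_small.shape)             # torch.Size([512, 1152])
```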