diff --git a/testing/merge_in_text_encoder_adapter.py b/testing/merge_in_text_encoder_adapter.py
index 9fda8a90..e4a5cff9 100644
--- a/testing/merge_in_text_encoder_adapter.py
+++ b/testing/merge_in_text_encoder_adapter.py
@@ -2,7 +2,7 @@ import os
 
 import torch
 from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel, PixArtTransformer2DModel
 from safetensors.torch import load_file, save_file
 from collections import OrderedDict
 import json
@@ -13,8 +13,8 @@ import json
 # output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
 model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS"
 te_path = "google/flan-t5-base"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS_t5base_raw"
+te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000227500.safetensors"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5base_raw"
 
 
 print("Loading te adapter")
@@ -28,11 +28,13 @@ is_pixart = "pixart" in model_path.lower()
 
 pipeline_class = StableDiffusionPipeline
 
+transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
+
 if is_pixart:
     pipeline_class = PixArtSigmaPipeline
 
 if is_diffusers:
-    sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
+    sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
 else:
     sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
 
@@ -113,6 +115,11 @@ for name in attn_processor_keys:
         if weight_idx > 1000:
             raise ValueError("Could not find the next weight")
 
+    orig_weight_shape_k = list(unet_sd[layer_name + ".to_k.weight"].shape)
+    new_weight_shape_k = list(te_aug_sd[te_aug_name + ".weight"].shape)
+    orig_weight_shape_v = list(unet_sd[layer_name + ".to_v.weight"].shape)
+    new_weight_shape_v = list(te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"].shape)
+
     unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"]
     unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"]
 
@@ -120,6 +127,14 @@ for name in attn_processor_keys:
 
 new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1]
 
+
+if is_pixart:
+    # remove the caption_projection weights; they are unused once caption_channels is None
+    del unet_sd['caption_projection.linear_1.bias']
+    del unet_sd['caption_projection.linear_1.weight']
+    del unet_sd['caption_projection.linear_2.bias']
+    del unet_sd['caption_projection.linear_2.weight']
+
 print("Saving unmodified model")
 sd = sd.to("cpu", torch.float16)
 sd.save_pretrained(
@@ -150,7 +165,7 @@ with open(os.path.join(unet_folder, "config.json"), 'r') as f:
 
 config['cross_attention_dim'] = new_cross_attn_dim
 if is_pixart:
-    config['caption_channels'] = te.config.d_model
+    config['caption_channels'] = None
 
 # save it
 with open(os.path.join(unet_folder, "config.json"), 'w') as f:
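
# Sketch (not part of the patch; the dimensions below are assumptions, not
# taken from the repo): the merge above relies on nn.Linear weights being
# [out_features, in_features]. Swapping the text encoder only changes the
# in_features (cross_attention_dim) of each cross-attention to_k/to_v, so the
# adapter weights can be dropped in as long as out_features, the attention
# inner dim, is unchanged. That is also why the config edits touch only
# cross_attention_dim and caption_channels.
import torch

def check_kv_swap(orig_w: torch.Tensor, new_w: torch.Tensor) -> None:
    # only the in_features axis may differ between old and new projections
    assert orig_w.shape[0] == new_w.shape[0], \
        f"attention inner dim mismatch: {orig_w.shape[0]} vs {new_w.shape[0]}"

orig_to_k = torch.randn(1152, 4096)    # e.g. PixArt inner dim x T5-XXL d_model
adapter_to_k = torch.randn(1152, 768)  # same inner dim x flan-t5-base d_model
check_kv_swap(orig_to_k, adapter_to_k)  # passes: only in_features changed
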
diff --git a/testing/shrink_pixart.py b/testing/shrink_pixart.py
index 1cac2a53..ad27b1a0 100644
--- a/testing/shrink_pixart.py
+++ b/testing/shrink_pixart.py
@@ -1,31 +1,18 @@
-import os
-
 import torch
-from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
 from safetensors.torch import load_file, save_file
 from collections import OrderedDict
-import json
 
-# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
-# te_path = "google/flan-t5-xl"
-# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
-# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
-model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_tiny.safetensors"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
+model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
 
 state_dict = load_file(model_path)
 
 meta = OrderedDict()
 meta["format"] = "pt"
 
-# has 28 blocks
-# keep block 0 and 27
-
 new_state_dict = {}
 
-# move non blocks over
+# Move non-blocks over
 for key, value in state_dict.items():
     if not key.startswith("transformer_blocks."):
         new_state_dict[key] = value
@@ -71,6 +58,5 @@ save_file(new_state_dict, output_path, metadata=meta)
 new_param_count = sum([v.numel() for v in new_state_dict.values()])
 old_param_count = sum([v.numel() for v in state_dict.values()])
 
-# porint comma formatted
 print(f"Old param count: {old_param_count:,}")
 print(f"New param count: {new_param_count:,}")
\ No newline at end of file
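
# Aside on the meta dict the shrink scripts pass to save_file: safetensors
# metadata is a flat str -> str mapping, and {"format": "pt"} is the usual
# marker for PyTorch checkpoints. A minimal round-trip check (the /tmp path
# is an assumption for illustration):
import torch
from safetensors import safe_open
from safetensors.torch import save_file

save_file({"w": torch.zeros(2, 2)}, "/tmp/meta_check.safetensors", metadata={"format": "pt"})
with safe_open("/tmp/meta_check.safetensors", framework="pt") as f:
    print(f.metadata())  # {'format': 'pt'}
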
diff --git a/testing/shrink_pixart2.py b/testing/shrink_pixart2.py
new file mode 100644
index 00000000..f8c30cf8
--- /dev/null
+++ b/testing/shrink_pixart2.py
@@ -0,0 +1,81 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
+
+state_dict = load_file(model_path)
+
+meta = OrderedDict()
+meta["format"] = "pt"
+
+new_state_dict = {}
+
+# Move non-blocks over
+for key, value in state_dict.items():
+    if not key.startswith("transformer_blocks."):
+        new_state_dict[key] = value
+
+block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
+               'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
+               'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
+               'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
+               'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
+               'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
+               'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
+               'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
+               'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
+               'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
+               'transformer_blocks.{idx}.scale_shift_table']
+
+# Blocks to keep
+# keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27]
+keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]
+
+
+def weighted_merge(kept_block, removed_block, weight):
+    return kept_block * (1 - weight) + removed_block * weight
+
+
+# First, copy all kept blocks to new_state_dict
+for i, old_idx in enumerate(keep_blocks):
+    for name in block_names:
+        old_key = name.format(idx=old_idx)
+        new_key = name.format(idx=i)
+        new_state_dict[new_key] = state_dict[old_key].clone()
+
+# Then, merge information from removed blocks
+for i in range(28):
+    if i not in keep_blocks:
+        # Find the nearest kept blocks
+        prev_kept = max([b for b in keep_blocks if b < i])
+        next_kept = min([b for b in keep_blocks if b > i])
+
+        # Calculate the weight based on position
+        weight = (i - prev_kept) / (next_kept - prev_kept)
+
+        for name in block_names:
+            removed_key = name.format(idx=i)
+            prev_new_key = name.format(idx=keep_blocks.index(prev_kept))
+            next_new_key = name.format(idx=keep_blocks.index(next_kept))
+
+            # Weighted merge for previous kept block
+            new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight)
+
+            # Weighted merge for next kept block
+            new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key],
+                                                          1 - weight)
+
+# Convert to fp16 and move to CPU
+for key, value in new_state_dict.items():
+    new_state_dict[key] = value.to(torch.float16).cpu()
+
+# Save the new state dict
+save_file(new_state_dict, output_path, metadata=meta)
+
+new_param_count = sum([v.numel() for v in new_state_dict.values()])
+old_param_count = sum([v.numel() for v in state_dict.values()])
+
+print(f"Old param count: {old_param_count:,}")
+print(f"New param count: {new_param_count:,}")
\ No newline at end of file
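
# With the keep_blocks list used above, every removed block (3, 5, ..., 25)
# sits exactly halfway between two kept neighbors, so weight is always 0.5 and
# each neighbor absorbs half of the removed block. A standalone check:
keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]
for i in range(28):
    if i not in keep_blocks:
        prev_kept = max(b for b in keep_blocks if b < i)
        next_kept = min(b for b in keep_blocks if b > i)
        assert (i - prev_kept) / (next_kept - prev_kept) == 0.5
# For unevenly spaced keep lists, note that weight grows toward next_kept, so
# the previous neighbor would absorb more of a removed block the closer it
# sits to the next kept one; the uniform spacing above sidesteps that question.
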
diff --git a/testing/shrink_pixart_sm.py b/testing/shrink_pixart_sm.py
new file mode 100644
index 00000000..c1e2f608
--- /dev/null
+++ b/testing/shrink_pixart_sm.py
@@ -0,0 +1,63 @@
+import torch
+from safetensors.torch import load_file, save_file
+from collections import OrderedDict
+
+model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.orig.safetensors"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.safetensors"
+
+state_dict = load_file(model_path)
+
+meta = OrderedDict()
+meta["format"] = "pt"
+
+new_state_dict = {}
+
+# Move non-blocks over
+for key, value in state_dict.items():
+    if not key.startswith("transformer_blocks."):
+        new_state_dict[key] = value
+
+block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
+               'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
+               'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
+               'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
+               'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
+               'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
+               'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
+               'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
+               'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
+               'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
+               'transformer_blocks.{idx}.scale_shift_table']
+
+# Keep only the first 10 blocks (0-9); the remaining blocks are dropped for now
+
+current_idx = 0
+for i in range(28):
+    if i not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
+        # todo merge dropped blocks into the previous kept block, e.g. a 50/50 running average:
+        # for name in block_names:
+        #     try:
+        #         new_state_dict_key = name.format(idx=current_idx - 1)
+        #         old_state_dict_key = name.format(idx=i)
+        #         new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5)
+        #     except KeyError:
+        #         raise KeyError(f"KeyError: {name.format(idx=current_idx)}")
+        pass
+    else:
+        for name in block_names:
+            new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)]
+        current_idx += 1
+
+
+# make sure they are all fp16 and on cpu
+for key, value in new_state_dict.items():
+    new_state_dict[key] = value.to(torch.float16).cpu()
+
+# save the new state dict
+save_file(new_state_dict, output_path, metadata=meta)
+
+new_param_count = sum([v.numel() for v in new_state_dict.values()])
+old_param_count = sum([v.numel() for v in state_dict.values()])
+
+print(f"Old param count: {old_param_count:,}")
+print(f"New param count: {new_param_count:,}")
\ No newline at end of file
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index ac552e3e..08880a82 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -1775,16 +1775,16 @@ class StableDiffusion:
         # saving in diffusers format
         if not output_file.endswith('.safetensors'):
             # diffusers
-            if self.is_pixart:
-                self.unet.save_pretrained(
-                    save_directory=output_file,
-                    safe_serialization=True,
-                )
-            else:
-                self.pipeline.save_pretrained(
-                    save_directory=output_file,
-                    safe_serialization=True,
-                )
+            # if self.is_pixart:
+            #     self.unet.save_pretrained(
+            #         save_directory=output_file,
+            #         safe_serialization=True,
+            #     )
+            # else:
+            self.pipeline.save_pretrained(
+                save_directory=output_file,
+                safe_serialization=True,
+            )
             # save out meta config
             meta_path = os.path.join(output_file, 'aitk_meta.yaml')
             with open(meta_path, 'w') as f:
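
# A quick sanity check for any of the shrunk checkpoints above (the path is
# an assumption): count the surviving block indices and total parameters.
# Note the scripts rewrite only the weights; the transformer's config.json
# (num_layers, normally 28 for these models) presumably has to be edited to
# match before diffusers will load the result.
from safetensors.torch import load_file

sd = load_file("diffusion_pytorch_model.safetensors")
blocks = {int(k.split(".")[1]) for k in sd if k.startswith("transformer_blocks.")}
print(f"blocks: {sorted(blocks)}")  # expect 0-15 for the 16-block variant
print(f"params: {sum(v.numel() for v in sd.values()):,}")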