Save entire pixart model again

Jaret Burkett
2024-07-07 07:56:48 -06:00
parent cab8a1c7b8
commit 045e4a6e15
5 changed files with 177 additions and 32 deletions

View File

@@ -2,7 +2,7 @@ import os
import torch
from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
+from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel, PixArtTransformer2DModel
from safetensors.torch import load_file, save_file
from collections import OrderedDict
import json
@@ -13,8 +13,8 @@ import json
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS"
te_path = "google/flan-t5-base"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
+te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000227500.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS_t5base_raw"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5base_raw"
print("Loading te adapter")
@@ -28,11 +28,13 @@ is_pixart = "pixart" in model_path.lower()
pipeline_class = StableDiffusionPipeline
+transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
if is_pixart:
    pipeline_class = PixArtSigmaPipeline
if is_diffusers:
-    sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
+    sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
else:
    sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
@@ -113,6 +115,11 @@ for name in attn_processor_keys:
        if weight_idx > 1000:
            raise ValueError("Could not find the next weight")
+    orig_weight_shape_k = list(unet_sd[layer_name + ".to_k.weight"].shape)
+    new_weight_shape_k = list(te_aug_sd[te_aug_name + ".weight"].shape)
+    orig_weight_shape_v = list(unet_sd[layer_name + ".to_v.weight"].shape)
+    new_weight_shape_v = list(te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"].shape)
    unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"]
    unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"]
@@ -120,6 +127,14 @@ for name in attn_processor_keys:
    new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1]
+if is_pixart:
+    # copy the caption_projection weight
+    del unet_sd['caption_projection.linear_1.bias']
+    del unet_sd['caption_projection.linear_1.weight']
+    del unet_sd['caption_projection.linear_2.bias']
+    del unet_sd['caption_projection.linear_2.weight']
print("Saving unmodified model")
sd = sd.to("cpu", torch.float16)
sd.save_pretrained(
@@ -150,7 +165,7 @@ with open(os.path.join(unet_folder, "config.json"), 'r') as f:
config['cross_attention_dim'] = new_cross_attn_dim
if is_pixart:
-    config['caption_channels'] = te.config.d_model
+    config['caption_channels'] = None
# save it
with open(os.path.join(unet_folder, "config.json"), 'w') as f:
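As a sanity check (not part of the commit), the rewritten transformer config can be inspected after the conversion runs; the "transformer" subfolder layout and the unet_folder variable are assumptions based on the script above:

import json
import os

# Hypothetical check: with to_k/to_v swapped to the T5-base adapter weights and
# caption_projection removed, the edited config should report
# cross_attention_dim == 768 (flan-t5-base d_model) and caption_channels == None.
unet_folder = os.path.join(output_path, "transformer")  # assumed layout
with open(os.path.join(unet_folder, "config.json")) as f:
    config = json.load(f)
print(config["cross_attention_dim"], config["caption_channels"])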

View File

@@ -1,31 +1,18 @@
-import os
import torch
-from transformers import T5EncoderModel, T5Tokenizer
-from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
from safetensors.torch import load_file, save_file
from collections import OrderedDict
-import json
-# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
+model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
-# te_path = "google/flan-t5-xl"
+output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
-# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
-# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
-model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
-output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_tiny.safetensors"
-te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
state_dict = load_file(model_path)
meta = OrderedDict()
meta["format"] = "pt"
-# has 28 blocks
-# keep block 0 and 27
new_state_dict = {}
-# move non blocks over
+# Move non-blocks over
for key, value in state_dict.items():
    if not key.startswith("transformer_blocks."):
        new_state_dict[key] = value
@@ -71,6 +58,5 @@ save_file(new_state_dict, output_path, metadata=meta)
new_param_count = sum([v.numel() for v in new_state_dict.values()])
old_param_count = sum([v.numel() for v in state_dict.values()])
-# porint comma formatted
print(f"Old param count: {old_param_count:,}")
print(f"New param count: {new_param_count:,}")

testing/shrink_pixart2.py Normal file
View File

@@ -0,0 +1,81 @@
import torch
from safetensors.torch import load_file, save_file
from collections import OrderedDict
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
state_dict = load_file(model_path)
meta = OrderedDict()
meta["format"] = "pt"
new_state_dict = {}
# Move non-blocks over
for key, value in state_dict.items():
    if not key.startswith("transformer_blocks."):
        new_state_dict[key] = value
block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
'transformer_blocks.{idx}.scale_shift_table']
# Blocks to keep
# keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27]
keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]
def weighted_merge(kept_block, removed_block, weight):
    return kept_block * (1 - weight) + removed_block * weight
# First, copy all kept blocks to new_state_dict
for i, old_idx in enumerate(keep_blocks):
    for name in block_names:
        old_key = name.format(idx=old_idx)
        new_key = name.format(idx=i)
        new_state_dict[new_key] = state_dict[old_key].clone()
# Then, merge information from removed blocks
for i in range(28):
    if i not in keep_blocks:
        # Find the nearest kept blocks
        prev_kept = max([b for b in keep_blocks if b < i])
        next_kept = min([b for b in keep_blocks if b > i])
        # Calculate the weight based on position
        weight = (i - prev_kept) / (next_kept - prev_kept)
        for name in block_names:
            removed_key = name.format(idx=i)
            prev_new_key = name.format(idx=keep_blocks.index(prev_kept))
            next_new_key = name.format(idx=keep_blocks.index(next_kept))
            # Weighted merge for previous kept block
            new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight)
            # Weighted merge for next kept block
            new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key],
                                                          1 - weight)
# Convert to fp16 and move to CPU
for key, value in new_state_dict.items():
    new_state_dict[key] = value.to(torch.float16).cpu()
# Save the new state dict
save_file(new_state_dict, output_path, metadata=meta)
new_param_count = sum([v.numel() for v in new_state_dict.values()])
old_param_count = sum([v.numel() for v in state_dict.values()])
print(f"Old param count: {old_param_count:,}")
print(f"New param count: {new_param_count:,}")

View File

@@ -0,0 +1,63 @@
import torch
from safetensors.torch import load_file, save_file
from collections import OrderedDict
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.orig.safetensors"
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.safetensors"
state_dict = load_file(model_path)
meta = OrderedDict()
meta["format"] = "pt"
new_state_dict = {}
# Move non-blocks over
for key, value in state_dict.items():
    if not key.startswith("transformer_blocks."):
        new_state_dict[key] = value
block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
'transformer_blocks.{idx}.scale_shift_table']
# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27
current_idx = 0
for i in range(28):
    if i not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
        # todo merge in with previous block
        for name in block_names:
            continue
            # try:
            # new_state_dict_key = name.format(idx=current_idx - 1)
            # old_state_dict_key = name.format(idx=i)
            # new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5)
            # except KeyError:
            # raise KeyError(f"KeyError: {name.format(idx=current_idx)}")
    else:
        for name in block_names:
            new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)]
        current_idx += 1
# make sure they are all fp16 and on cpu
for key, value in new_state_dict.items():
    new_state_dict[key] = value.to(torch.float16).cpu()
# save the new state dict
save_file(new_state_dict, output_path, metadata=meta)
new_param_count = sum([v.numel() for v in new_state_dict.values()])
old_param_count = sum([v.numel() for v in state_dict.values()])
print(f"Old param count: {old_param_count:,}")
print(f"New param count: {new_param_count:,}")

View File

@@ -1775,16 +1775,16 @@ class StableDiffusion:
# saving in diffusers format
if not output_file.endswith('.safetensors'):
    # diffusers
-    if self.is_pixart:
-        self.unet.save_pretrained(
-            save_directory=output_file,
-            safe_serialization=True,
-        )
-    else:
+    # if self.is_pixart:
+    #     self.unet.save_pretrained(
+    #         save_directory=output_file,
+    #         safe_serialization=True,
+    #     )
+    # else:
    self.pipeline.save_pretrained(
        save_directory=output_file,
        safe_serialization=True,
    )
    # save out meta config
    meta_path = os.path.join(output_file, 'aitk_meta.yaml')
    with open(meta_path, 'w') as f:
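The change above restores pipeline-level saving for PixArt models: previously the is_pixart branch wrote out only self.unet (the transformer), whereas now the whole pipeline is saved, which is what the commit title refers to. A hedged usage sketch, assuming output_file is the diffusers directory written by save_pretrained:

import torch
from diffusers import PixArtSigmaPipeline

# Reload the full pipeline (transformer, T5 text encoder, VAE, scheduler)
# from the directory saved by the code path above.
pipe = PixArtSigmaPipeline.from_pretrained(output_file, torch_dtype=torch.float16)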