Save entire pixart model again

This commit is contained in:
Jaret Burkett
2024-07-07 07:56:48 -06:00
parent cab8a1c7b8
commit 045e4a6e15
5 changed files with 177 additions and 32 deletions

View File

@@ -2,7 +2,7 @@ import os
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel, PixArtTransformer2DModel
from safetensors.torch import load_file, save_file
from collections import OrderedDict
import json
@@ -13,8 +13,8 @@ import json
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS"
te_path = "google/flan-t5-base"
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS_t5base_raw"
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000227500.safetensors"
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5base_raw"
print("Loading te adapter")
@@ -28,11 +28,13 @@ is_pixart = "pixart" in model_path.lower()
pipeline_class = StableDiffusionPipeline
transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
if is_pixart:
pipeline_class = PixArtSigmaPipeline
if is_diffusers:
sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
else:
sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
@@ -113,6 +115,11 @@ for name in attn_processor_keys:
if weight_idx > 1000:
raise ValueError("Could not find the next weight")
orig_weight_shape_k = list(unet_sd[layer_name + ".to_k.weight"].shape)
new_weight_shape_k = list(te_aug_sd[te_aug_name + ".weight"].shape)
orig_weight_shape_v = list(unet_sd[layer_name + ".to_v.weight"].shape)
new_weight_shape_v = list(te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"].shape)
unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"]
unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"]
@@ -120,6 +127,14 @@ for name in attn_processor_keys:
new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1]
if is_pixart:
# copy the caption_projection weight
del unet_sd['caption_projection.linear_1.bias']
del unet_sd['caption_projection.linear_1.weight']
del unet_sd['caption_projection.linear_2.bias']
del unet_sd['caption_projection.linear_2.weight']
print("Saving unmodified model")
sd = sd.to("cpu", torch.float16)
sd.save_pretrained(
@@ -150,7 +165,7 @@ with open(os.path.join(unet_folder, "config.json"), 'r') as f:
config['cross_attention_dim'] = new_cross_attn_dim
if is_pixart:
config['caption_channels'] = te.config.d_model
config['caption_channels'] = None
# save it
with open(os.path.join(unet_folder, "config.json"), 'w') as f:

View File

@@ -1,31 +1,18 @@
import os
import torch
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
from safetensors.torch import load_file, save_file
from collections import OrderedDict
import json
# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
# te_path = "google/flan-t5-xl"
# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_tiny.safetensors"
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
state_dict = load_file(model_path)
meta = OrderedDict()
meta["format"] = "pt"
# has 28 blocks
# keep block 0 and 27
new_state_dict = {}
# move non blocks over
# Move non-blocks over
for key, value in state_dict.items():
if not key.startswith("transformer_blocks."):
new_state_dict[key] = value
@@ -71,6 +58,5 @@ save_file(new_state_dict, output_path, metadata=meta)
new_param_count = sum([v.numel() for v in new_state_dict.values()])
old_param_count = sum([v.numel() for v in state_dict.values()])
# porint comma formatted
print(f"Old param count: {old_param_count:,}")
print(f"New param count: {new_param_count:,}")

81
testing/shrink_pixart2.py Normal file
View File

@@ -0,0 +1,81 @@
import torch
from safetensors.torch import load_file, save_file
from collections import OrderedDict
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
state_dict = load_file(model_path)
meta = OrderedDict()
meta["format"] = "pt"
new_state_dict = {}
# Move non-blocks over
for key, value in state_dict.items():
if not key.startswith("transformer_blocks."):
new_state_dict[key] = value
block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
'transformer_blocks.{idx}.scale_shift_table']
# Blocks to keep
# keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27]
keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]
def weighted_merge(kept_block, removed_block, weight):
return kept_block * (1 - weight) + removed_block * weight
# First, copy all kept blocks to new_state_dict
for i, old_idx in enumerate(keep_blocks):
for name in block_names:
old_key = name.format(idx=old_idx)
new_key = name.format(idx=i)
new_state_dict[new_key] = state_dict[old_key].clone()
# Then, merge information from removed blocks
for i in range(28):
if i not in keep_blocks:
# Find the nearest kept blocks
prev_kept = max([b for b in keep_blocks if b < i])
next_kept = min([b for b in keep_blocks if b > i])
# Calculate the weight based on position
weight = (i - prev_kept) / (next_kept - prev_kept)
for name in block_names:
removed_key = name.format(idx=i)
prev_new_key = name.format(idx=keep_blocks.index(prev_kept))
next_new_key = name.format(idx=keep_blocks.index(next_kept))
# Weighted merge for previous kept block
new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight)
# Weighted merge for next kept block
new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key],
1 - weight)
# Convert to fp16 and move to CPU
for key, value in new_state_dict.items():
new_state_dict[key] = value.to(torch.float16).cpu()
# Save the new state dict
save_file(new_state_dict, output_path, metadata=meta)
new_param_count = sum([v.numel() for v in new_state_dict.values()])
old_param_count = sum([v.numel() for v in state_dict.values()])
print(f"Old param count: {old_param_count:,}")
print(f"New param count: {new_param_count:,}")

View File

@@ -0,0 +1,63 @@
import torch
from safetensors.torch import load_file, save_file
from collections import OrderedDict
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.orig.safetensors"
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.safetensors"
state_dict = load_file(model_path)
meta = OrderedDict()
meta["format"] = "pt"
new_state_dict = {}
# Move non-blocks over
for key, value in state_dict.items():
if not key.startswith("transformer_blocks."):
new_state_dict[key] = value
block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
'transformer_blocks.{idx}.scale_shift_table']
# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27
current_idx = 0
for i in range(28):
if i not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
# todo merge in with previous block
for name in block_names:
continue
# try:
# new_state_dict_key = name.format(idx=current_idx - 1)
# old_state_dict_key = name.format(idx=i)
# new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5)
# except KeyError:
# raise KeyError(f"KeyError: {name.format(idx=current_idx)}")
else:
for name in block_names:
new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)]
current_idx += 1
# make sure they are all fp16 and on cpu
for key, value in new_state_dict.items():
new_state_dict[key] = value.to(torch.float16).cpu()
# save the new state dict
save_file(new_state_dict, output_path, metadata=meta)
new_param_count = sum([v.numel() for v in new_state_dict.values()])
old_param_count = sum([v.numel() for v in state_dict.values()])
print(f"Old param count: {old_param_count:,}")
print(f"New param count: {new_param_count:,}")

View File

@@ -1775,16 +1775,16 @@ class StableDiffusion:
# saving in diffusers format
if not output_file.endswith('.safetensors'):
# diffusers
if self.is_pixart:
self.unet.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
else:
self.pipeline.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
# if self.is_pixart:
# self.unet.save_pretrained(
# save_directory=output_file,
# safe_serialization=True,
# )
# else:
self.pipeline.save_pretrained(
save_directory=output_file,
safe_serialization=True,
)
# save out meta config
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
with open(meta_path, 'w') as f: