mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 08:29:45 +00:00
Save entire pixart model again
This commit is contained in:
@@ -2,7 +2,7 @@ import os
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel, PixArtTransformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
@@ -13,8 +13,8 @@ import json
|
||||
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS"
|
||||
te_path = "google/flan-t5-base"
|
||||
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024-MS_t5base_raw"
|
||||
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000227500.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5base_raw"
|
||||
|
||||
|
||||
print("Loading te adapter")
|
||||
@@ -28,11 +28,13 @@ is_pixart = "pixart" in model_path.lower()
|
||||
|
||||
pipeline_class = StableDiffusionPipeline
|
||||
|
||||
transformer = PixArtTransformer2DModel.from_pretrained('PixArt-alpha/PixArt-Sigma-XL-2-512-MS', subfolder='transformer', torch_dtype=torch.float16)
|
||||
|
||||
if is_pixart:
|
||||
pipeline_class = PixArtSigmaPipeline
|
||||
|
||||
if is_diffusers:
|
||||
sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
sd = pipeline_class.from_pretrained(model_path, transformer=transformer, torch_dtype=torch.float16)
|
||||
else:
|
||||
sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
|
||||
|
||||
@@ -113,6 +115,11 @@ for name in attn_processor_keys:
|
||||
if weight_idx > 1000:
|
||||
raise ValueError("Could not find the next weight")
|
||||
|
||||
orig_weight_shape_k = list(unet_sd[layer_name + ".to_k.weight"].shape)
|
||||
new_weight_shape_k = list(te_aug_sd[te_aug_name + ".weight"].shape)
|
||||
orig_weight_shape_v = list(unet_sd[layer_name + ".to_v.weight"].shape)
|
||||
new_weight_shape_v = list(te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"].shape)
|
||||
|
||||
unet_sd[layer_name + ".to_k.weight"] = te_aug_sd[te_aug_name + ".weight"]
|
||||
unet_sd[layer_name + ".to_v.weight"] = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"]
|
||||
|
||||
@@ -120,6 +127,14 @@ for name in attn_processor_keys:
|
||||
new_cross_attn_dim = unet_sd[layer_name + ".to_k.weight"].shape[1]
|
||||
|
||||
|
||||
|
||||
if is_pixart:
|
||||
# copy the caption_projection weight
|
||||
del unet_sd['caption_projection.linear_1.bias']
|
||||
del unet_sd['caption_projection.linear_1.weight']
|
||||
del unet_sd['caption_projection.linear_2.bias']
|
||||
del unet_sd['caption_projection.linear_2.weight']
|
||||
|
||||
print("Saving unmodified model")
|
||||
sd = sd.to("cpu", torch.float16)
|
||||
sd.save_pretrained(
|
||||
@@ -150,7 +165,7 @@ with open(os.path.join(unet_folder, "config.json"), 'r') as f:
|
||||
config['cross_attention_dim'] = new_cross_attn_dim
|
||||
|
||||
if is_pixart:
|
||||
config['caption_channels'] = te.config.d_model
|
||||
config['caption_channels'] = None
|
||||
|
||||
# save it
|
||||
with open(os.path.join(unet_folder, "config.json"), 'w') as f:
|
||||
|
||||
@@ -1,31 +1,18 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
|
||||
# te_path = "google/flan-t5-xl"
|
||||
# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
||||
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_tiny.safetensors"
|
||||
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
|
||||
|
||||
state_dict = load_file(model_path)
|
||||
|
||||
meta = OrderedDict()
|
||||
meta["format"] = "pt"
|
||||
|
||||
# has 28 blocks
|
||||
# keep block 0 and 27
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# move non blocks over
|
||||
# Move non-blocks over
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("transformer_blocks."):
|
||||
new_state_dict[key] = value
|
||||
@@ -71,6 +58,5 @@ save_file(new_state_dict, output_path, metadata=meta)
|
||||
new_param_count = sum([v.numel() for v in new_state_dict.values()])
|
||||
old_param_count = sum([v.numel() for v in state_dict.values()])
|
||||
|
||||
# porint comma formatted
|
||||
print(f"Old param count: {old_param_count:,}")
|
||||
print(f"New param count: {new_param_count:,}")
|
||||
81
testing/shrink_pixart2.py
Normal file
81
testing/shrink_pixart2.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
|
||||
|
||||
state_dict = load_file(model_path)
|
||||
|
||||
meta = OrderedDict()
|
||||
meta["format"] = "pt"
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# Move non-blocks over
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("transformer_blocks."):
|
||||
new_state_dict[key] = value
|
||||
|
||||
block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
|
||||
'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
|
||||
'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
|
||||
'transformer_blocks.{idx}.scale_shift_table']
|
||||
|
||||
# Blocks to keep
|
||||
# keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27]
|
||||
keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]
|
||||
|
||||
|
||||
def weighted_merge(kept_block, removed_block, weight):
|
||||
return kept_block * (1 - weight) + removed_block * weight
|
||||
|
||||
|
||||
# First, copy all kept blocks to new_state_dict
|
||||
for i, old_idx in enumerate(keep_blocks):
|
||||
for name in block_names:
|
||||
old_key = name.format(idx=old_idx)
|
||||
new_key = name.format(idx=i)
|
||||
new_state_dict[new_key] = state_dict[old_key].clone()
|
||||
|
||||
# Then, merge information from removed blocks
|
||||
for i in range(28):
|
||||
if i not in keep_blocks:
|
||||
# Find the nearest kept blocks
|
||||
prev_kept = max([b for b in keep_blocks if b < i])
|
||||
next_kept = min([b for b in keep_blocks if b > i])
|
||||
|
||||
# Calculate the weight based on position
|
||||
weight = (i - prev_kept) / (next_kept - prev_kept)
|
||||
|
||||
for name in block_names:
|
||||
removed_key = name.format(idx=i)
|
||||
prev_new_key = name.format(idx=keep_blocks.index(prev_kept))
|
||||
next_new_key = name.format(idx=keep_blocks.index(next_kept))
|
||||
|
||||
# Weighted merge for previous kept block
|
||||
new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight)
|
||||
|
||||
# Weighted merge for next kept block
|
||||
new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key],
|
||||
1 - weight)
|
||||
|
||||
# Convert to fp16 and move to CPU
|
||||
for key, value in new_state_dict.items():
|
||||
new_state_dict[key] = value.to(torch.float16).cpu()
|
||||
|
||||
# Save the new state dict
|
||||
save_file(new_state_dict, output_path, metadata=meta)
|
||||
|
||||
new_param_count = sum([v.numel() for v in new_state_dict.values()])
|
||||
old_param_count = sum([v.numel() for v in state_dict.values()])
|
||||
|
||||
print(f"Old param count: {old_param_count:,}")
|
||||
print(f"New param count: {new_param_count:,}")
|
||||
63
testing/shrink_pixart_sm.py
Normal file
63
testing/shrink_pixart_sm.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.orig.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_tiny/transformer/diffusion_pytorch_model.safetensors"
|
||||
|
||||
state_dict = load_file(model_path)
|
||||
|
||||
meta = OrderedDict()
|
||||
meta["format"] = "pt"
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# Move non-blocks over
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("transformer_blocks."):
|
||||
new_state_dict[key] = value
|
||||
|
||||
block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
|
||||
'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
|
||||
'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
|
||||
'transformer_blocks.{idx}.scale_shift_table']
|
||||
|
||||
# New block idx 0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27
|
||||
|
||||
current_idx = 0
|
||||
for i in range(28):
|
||||
if i not in [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]:
|
||||
# todo merge in with previous block
|
||||
for name in block_names:
|
||||
continue
|
||||
# try:
|
||||
# new_state_dict_key = name.format(idx=current_idx - 1)
|
||||
# old_state_dict_key = name.format(idx=i)
|
||||
# new_state_dict[new_state_dict_key] = (new_state_dict[new_state_dict_key] * 0.5) + (state_dict[old_state_dict_key] * 0.5)
|
||||
# except KeyError:
|
||||
# raise KeyError(f"KeyError: {name.format(idx=current_idx)}")
|
||||
else:
|
||||
for name in block_names:
|
||||
new_state_dict[name.format(idx=current_idx)] = state_dict[name.format(idx=i)]
|
||||
current_idx += 1
|
||||
|
||||
|
||||
# make sure they are all fp16 and on cpu
|
||||
for key, value in new_state_dict.items():
|
||||
new_state_dict[key] = value.to(torch.float16).cpu()
|
||||
|
||||
# save the new state dict
|
||||
save_file(new_state_dict, output_path, metadata=meta)
|
||||
|
||||
new_param_count = sum([v.numel() for v in new_state_dict.values()])
|
||||
old_param_count = sum([v.numel() for v in state_dict.values()])
|
||||
|
||||
print(f"Old param count: {old_param_count:,}")
|
||||
print(f"New param count: {new_param_count:,}")
|
||||
@@ -1775,16 +1775,16 @@ class StableDiffusion:
|
||||
# saving in diffusers format
|
||||
if not output_file.endswith('.safetensors'):
|
||||
# diffusers
|
||||
if self.is_pixart:
|
||||
self.unet.save_pretrained(
|
||||
save_directory=output_file,
|
||||
safe_serialization=True,
|
||||
)
|
||||
else:
|
||||
self.pipeline.save_pretrained(
|
||||
save_directory=output_file,
|
||||
safe_serialization=True,
|
||||
)
|
||||
# if self.is_pixart:
|
||||
# self.unet.save_pretrained(
|
||||
# save_directory=output_file,
|
||||
# safe_serialization=True,
|
||||
# )
|
||||
# else:
|
||||
self.pipeline.save_pretrained(
|
||||
save_directory=output_file,
|
||||
safe_serialization=True,
|
||||
)
|
||||
# save out meta config
|
||||
meta_path = os.path.join(output_file, 'aitk_meta.yaml')
|
||||
with open(meta_path, 'w') as f:
|
||||
|
||||
Reference in New Issue
Block a user