mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Save entire pixart model again
This commit is contained in:
81
testing/shrink_pixart2.py
Normal file
81
testing/shrink_pixart2.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
|
||||
|
||||
state_dict = load_file(model_path)
|
||||
|
||||
meta = OrderedDict()
|
||||
meta["format"] = "pt"
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# Move non-blocks over
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("transformer_blocks."):
|
||||
new_state_dict[key] = value
|
||||
|
||||
block_names = ['transformer_blocks.{idx}.attn1.to_k.bias', 'transformer_blocks.{idx}.attn1.to_k.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_out.0.bias', 'transformer_blocks.{idx}.attn1.to_out.0.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_q.bias', 'transformer_blocks.{idx}.attn1.to_q.weight',
|
||||
'transformer_blocks.{idx}.attn1.to_v.bias', 'transformer_blocks.{idx}.attn1.to_v.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_k.bias', 'transformer_blocks.{idx}.attn2.to_k.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_out.0.bias', 'transformer_blocks.{idx}.attn2.to_out.0.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_q.bias', 'transformer_blocks.{idx}.attn2.to_q.weight',
|
||||
'transformer_blocks.{idx}.attn2.to_v.bias', 'transformer_blocks.{idx}.attn2.to_v.weight',
|
||||
'transformer_blocks.{idx}.ff.net.0.proj.bias', 'transformer_blocks.{idx}.ff.net.0.proj.weight',
|
||||
'transformer_blocks.{idx}.ff.net.2.bias', 'transformer_blocks.{idx}.ff.net.2.weight',
|
||||
'transformer_blocks.{idx}.scale_shift_table']
|
||||
|
||||
# Blocks to keep
|
||||
# keep_blocks = [0, 1, 2, 6, 10, 14, 18, 22, 26, 27]
|
||||
keep_blocks = [0, 1, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 27]
|
||||
|
||||
|
||||
def weighted_merge(kept_block, removed_block, weight):
|
||||
return kept_block * (1 - weight) + removed_block * weight
|
||||
|
||||
|
||||
# First, copy all kept blocks to new_state_dict
|
||||
for i, old_idx in enumerate(keep_blocks):
|
||||
for name in block_names:
|
||||
old_key = name.format(idx=old_idx)
|
||||
new_key = name.format(idx=i)
|
||||
new_state_dict[new_key] = state_dict[old_key].clone()
|
||||
|
||||
# Then, merge information from removed blocks
|
||||
for i in range(28):
|
||||
if i not in keep_blocks:
|
||||
# Find the nearest kept blocks
|
||||
prev_kept = max([b for b in keep_blocks if b < i])
|
||||
next_kept = min([b for b in keep_blocks if b > i])
|
||||
|
||||
# Calculate the weight based on position
|
||||
weight = (i - prev_kept) / (next_kept - prev_kept)
|
||||
|
||||
for name in block_names:
|
||||
removed_key = name.format(idx=i)
|
||||
prev_new_key = name.format(idx=keep_blocks.index(prev_kept))
|
||||
next_new_key = name.format(idx=keep_blocks.index(next_kept))
|
||||
|
||||
# Weighted merge for previous kept block
|
||||
new_state_dict[prev_new_key] = weighted_merge(new_state_dict[prev_new_key], state_dict[removed_key], weight)
|
||||
|
||||
# Weighted merge for next kept block
|
||||
new_state_dict[next_new_key] = weighted_merge(new_state_dict[next_new_key], state_dict[removed_key],
|
||||
1 - weight)
|
||||
|
||||
# Convert to fp16 and move to CPU
|
||||
for key, value in new_state_dict.items():
|
||||
new_state_dict[key] = value.to(torch.float16).cpu()
|
||||
|
||||
# Save the new state dict
|
||||
save_file(new_state_dict, output_path, metadata=meta)
|
||||
|
||||
new_param_count = sum([v.numel() for v in new_state_dict.values()])
|
||||
old_param_count = sum([v.numel() for v in state_dict.values()])
|
||||
|
||||
print(f"Old param count: {old_param_count:,}")
|
||||
print(f"New param count: {new_param_count:,}")
|
||||
Reference in New Issue
Block a user