Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00).
Python · 181 lines · 6.0 KiB
import os
|
|
|
|
import torch
|
|
from transformers import T5EncoderModel, T5Tokenizer
|
|
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel, PixArtTransformer2DModel
|
|
from safetensors.torch import load_file, save_file
|
|
from collections import OrderedDict
|
|
import json
|
|
|
|
# Alternate configuration, kept for reference (an "sd15"-named base model paired with
# flan-t5-xl and its TE-adapter checkpoint):
#   model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
#   te_path = "google/flan-t5-xl"
#   te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
#   output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"

# Active configuration: PixArt-Sigma with a flan-t5-large text encoder.

# Base pipeline: a Hub repo id, a local diffusers directory, or a single-file checkpoint.
model_path = "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS"

# Text encoder / tokenizer that replaces the pipeline's original one.
te_path = "google/flan-t5-large"

# Trained TE-adapter checkpoint providing the replacement to_k / to_v projections.
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5l_000034000.safetensors"

# Directory the converted pipeline is written to.
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-512_MS_t5large_raw"
# Load the TE-adapter state dict. Its projection keys are looked up later as
# "te_adapter.adapter_modules.<idx>.to_k_adapter.weight" and the matching
# "to_v_adapter.weight".
print("Loading te adapter")
te_aug_sd = load_file(
    te_aug_path
)
# Load the base pipeline in fp16.
#   - PixArt models use PixArtSigmaPipeline, with the transformer taken from the 512-MS
#     repo (the output_path name suggests a 512px target).
#   - Everything else is loaded as a StableDiffusionPipeline with its own UNet.
# Previously the PixArt transformer was downloaded and passed to from_pretrained for every
# model, including StableDiffusionPipeline, which has no "transformer" component.
print("Loading model")

# Hub repo ids (which do not exist as local paths) and local directories are loaded with
# from_pretrained; an existing local file is treated as a single-file checkpoint.
is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path)

is_pixart = "pixart" in model_path.lower()

# Extra components to override when loading a diffusers-format pipeline.
pipeline_kwargs = {}

if is_pixart:
    pipeline_class = PixArtSigmaPipeline
    if is_diffusers:
        # Only from_pretrained receives the override, matching the original single-file
        # path, which never used it.
        transformer = PixArtTransformer2DModel.from_pretrained(
            'PixArt-alpha/PixArt-Sigma-XL-2-512-MS',
            subfolder='transformer',
            torch_dtype=torch.float16,
        )
        pipeline_kwargs["transformer"] = transformer
else:
    pipeline_class = StableDiffusionPipeline

if is_diffusers:
    sd = pipeline_class.from_pretrained(model_path, torch_dtype=torch.float16, **pipeline_kwargs)
else:
    sd = pipeline_class.from_single_file(model_path, torch_dtype=torch.float16)
# Replace the pipeline's text encoder and tokenizer with the T5 checkpoint at te_path.
# The encoder is loaded in fp16, the same dtype the pipeline was loaded with.
print("Loading Text Encoder")
t5_encoder = T5EncoderModel.from_pretrained(te_path, torch_dtype=torch.float16)
sd.text_encoder = t5_encoder
t5_tokenizer = T5Tokenizer.from_pretrained(te_path)
sd.tokenizer = t5_tokenizer
# Pick the denoiser whose cross-attention projections get patched:
#   - the transformer for PixArt;
#   - the UNet otherwise. This branch previously read sd.transformer, which a
#     StableDiffusionPipeline does not have.
# weight_idx is the adapter-module index the key search starts from (0 for PixArt,
# 1 otherwise). NOTE(review): presumably this mirrors how the adapter numbered its
# modules for each architecture — confirm against the training code.
if is_pixart:
    unet = sd.transformer
    weight_idx = 0
else:
    unet = sd.unet
    weight_idx = 1

unet_sd = unet.state_dict()

# Input width of the patched to_k/to_v projections; set while patching.
new_cross_attn_dim = None

# Parameter count before patching, reported at the end.
start_params = sum(v.numel() for v in unet_sd.values())
print("Building")

# Names of the attention modules to inspect.
if is_pixart:
    # Each PixArt transformer block has a self-attention (attn1) and a cross-attention
    # (attn2) module. named_children() yields the block index as its name. The denoiser
    # is read through `unet` directly instead of rebinding (and shadowing) the global
    # `transformer` under an incorrect Transformer2DModel annotation.
    attn_processor_keys = [
        f"transformer_blocks.{block_name}.{attn}"
        for block_name, _ in unet.transformer_blocks.named_children()
        for attn in ("attn1", "attn2")
    ]
else:
    # UNet processor names; any ".processor" suffix is stripped when patching.
    attn_processor_keys = list(unet.attn_processors.keys())
# Patch every cross-attention layer: replace its to_k / to_v weights with the next unused
# pair of TE-adapter projections from te_aug_sd. Self-attention layers (attn1) are skipped.
#
# Adapter keys have the form "te_adapter.adapter_modules.<idx>.to_k_adapter.weight" (plus
# the matching "to_v_adapter"). weight_idx advances past every index examined, so layers
# are paired with adapter modules in order.
SELF_ATTN_SUFFIXES = ("attn1.processor", "attn.1", "attn1")
KNOWN_PREFIXES = ("mid_block", "up_blocks", "down_blocks", "transformer")

for name in attn_processor_keys:
    if not name.startswith(KNOWN_PREFIXES):
        raise ValueError(f"unknown attn processor name: {name}")
    if name.endswith(SELF_ATTN_SUFFIXES):
        # Self-attention takes no text-encoder input; nothing to patch.
        continue

    layer_name = name.split(".processor")[0]
    k_key = layer_name + ".to_k.weight"
    v_key = layer_name + ".to_v.weight"

    # Bounded search for the next adapter module that has a to_k projection.
    while True:
        te_aug_name = f"te_adapter.adapter_modules.{weight_idx}.to_k_adapter"
        weight_idx += 1
        if f"{te_aug_name}.weight" in te_aug_sd:
            break
        if weight_idx > 1000:
            raise ValueError("Could not find the next weight")

    new_k = te_aug_sd[te_aug_name + ".weight"]
    new_v = te_aug_sd[te_aug_name.replace('to_k', 'to_v') + ".weight"]

    # Linear weights are (out_features, in_features). Only in_features (the text-encoder
    # width) may change. out_features must match the existing layer, otherwise the
    # saved model cannot be loaded with the rewritten config.
    for key, new_weight in ((k_key, new_k), (v_key, new_v)):
        orig_shape = tuple(unet_sd[key].shape)
        if new_weight.shape[0] != orig_shape[0]:
            raise ValueError(
                f"{key}: adapter weight shape {tuple(new_weight.shape)} does not match "
                f"the output size of the original weight {orig_shape}"
            )

    # The config stores a single cross_attention_dim, so every patched projection must
    # agree on its input width.
    for new_weight in (new_k, new_v):
        if new_cross_attn_dim is None:
            new_cross_attn_dim = new_weight.shape[1]
        elif new_weight.shape[1] != new_cross_attn_dim:
            raise ValueError(
                f"{layer_name}: adapter input width {new_weight.shape[1]} differs from "
                f"previously patched layers ({new_cross_attn_dim})"
            )

    unet_sd[k_key] = new_k
    unet_sd[v_key] = new_v

if new_cross_attn_dim is None:
    # Without this check, cross_attention_dim would be written to config.json as null.
    raise ValueError("no cross-attention layers were patched; cannot determine cross_attention_dim")
if is_pixart:
    # Remove (do not copy) the caption projection weights. The patched config sets
    # caption_channels to None, so the rebuilt transformer is not expected to have
    # this projection.
    caption_projection_keys = (
        'caption_projection.linear_1.bias',
        'caption_projection.linear_1.weight',
        'caption_projection.linear_2.bias',
        'caption_projection.linear_2.weight',
    )
    for proj_key in caption_projection_keys:
        del unet_sd[proj_key]
# Write the whole pipeline (including the replaced text encoder and tokenizer) in fp16
# on CPU. The denoiser weights saved here are still the unpatched originals; the
# patched state dict is written over them afterwards.
print("Saving unmodified model")
sd = sd.to("cpu", torch.float16)
sd.save_pretrained(output_path, safe_serialization=True)
# Overwrite the saved denoiser weights with the patched state dict.
unet_folder = os.path.join(output_path, "transformer" if is_pixart else "unet")

# Copy each tensor off the pipeline, move it to CPU and cast it to fp16 before saving.
unet_sd = {
    param_name: tensor.clone().cpu().to(torch.float16)
    for param_name, tensor in unet_sd.items()
}

meta = OrderedDict(format="pt")

print("Patching")
weights_file = os.path.join(unet_folder, "diffusion_pytorch_model.safetensors")
save_file(unet_sd, weights_file, meta)
# Rewrite the denoiser config to match the patched weights:
#   - cross_attention_dim becomes the new projection input width;
#   - for PixArt, caption_channels is cleared because the caption projection was removed.
config_path = os.path.join(unet_folder, "config.json")

with open(config_path, 'r') as f:
    config = json.load(f)

overrides = {'cross_attention_dim': new_cross_attn_dim}
if is_pixart:
    overrides['caption_channels'] = None
config.update(overrides)

with open(config_path, 'w') as f:
    json.dump(config, f, indent=2)
print("Done")

# Report the denoiser parameter count before and after patching, with thousands separators.
new_params = sum(tensor.numel() for tensor in unet_sd.values())
for label, count in (("Old", start_params), ("New", new_params)):
    print(f"{label} params: {count:,}")