mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Save entire pixart model again
This commit is contained in:
@@ -1,31 +1,18 @@
|
||||
import os
|
||||
|
||||
import torch
|
||||
from transformers import T5EncoderModel, T5Tokenizer
|
||||
from diffusers import StableDiffusionPipeline, UNet2DConditionModel, PixArtSigmaPipeline, Transformer2DModel
|
||||
from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
# model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
|
||||
# te_path = "google/flan-t5-xl"
|
||||
# te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
||||
# output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_tiny.safetensors"
|
||||
te_aug_path = "/home/jaret/Dev/models/tmp/pixart_sigma_t5base_000204000.safetensors"
|
||||
model_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model_orig.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/PixArt-Sigma-XL-2-1024_tiny/transformer/diffusion_pytorch_model.safetensors"
|
||||
|
||||
state_dict = load_file(model_path)
|
||||
|
||||
meta = OrderedDict()
|
||||
meta["format"] = "pt"
|
||||
|
||||
# has 28 blocks
|
||||
# keep block 0 and 27
|
||||
|
||||
new_state_dict = {}
|
||||
|
||||
# move non blocks over
|
||||
# Move non-blocks over
|
||||
for key, value in state_dict.items():
|
||||
if not key.startswith("transformer_blocks."):
|
||||
new_state_dict[key] = value
|
||||
@@ -71,6 +58,5 @@ save_file(new_state_dict, output_path, metadata=meta)
|
||||
new_param_count = sum([v.numel() for v in new_state_dict.values()])
|
||||
old_param_count = sum([v.numel() for v in state_dict.values()])
|
||||
|
||||
# porint comma formatted
|
||||
print(f"Old param count: {old_param_count:,}")
|
||||
print(f"New param count: {new_param_count:,}")
|
||||
Reference in New Issue
Block a user