mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-25 16:59:22 +00:00
Minor fixes
This commit is contained in:
@@ -7,16 +7,21 @@ from safetensors.torch import load_file, save_file
|
||||
from collections import OrderedDict
|
||||
import json
|
||||
|
||||
model_path = "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/objective_reality_v2.safetensors"
|
||||
model_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_v01_000527000"
|
||||
te_path = "google/flan-t5-xl"
|
||||
te_aug_path = "/mnt/Train/out/ip_adapter/t5xx_sd15_v1/t5xx_sd15_v1_000032000.safetensors"
|
||||
output_path = "/home/jaret/Dev/models/hf/t5xl_sd15_v1"
|
||||
output_path = "/home/jaret/Dev/models/hf/kl-f16-d42_sd15_t5xl_raw"
|
||||
|
||||
print("Loading te adapter")
|
||||
te_aug_sd = load_file(te_aug_path)
|
||||
|
||||
print("Loading model")
|
||||
sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
|
||||
is_diffusers = (not os.path.exists(model_path)) or os.path.isdir(model_path)
|
||||
|
||||
if is_diffusers:
|
||||
sd = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||
else:
|
||||
sd = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=torch.float16)
|
||||
|
||||
print("Loading Text Encoder")
|
||||
# Load the text encoder
|
||||
@@ -74,6 +79,7 @@ for name in sd.unet.attn_processors.keys():
|
||||
|
||||
|
||||
print("Saving unmodified model")
|
||||
sd = sd.to("cpu", torch.float16)
|
||||
sd.save_pretrained(
|
||||
output_path,
|
||||
safe_serialization=True,
|
||||
|
||||
Reference in New Issue
Block a user