Bug fixes. allow for random negative prompts

This commit is contained in:
Jaret Burkett
2024-02-21 04:51:52 -07:00
parent 2478554c95
commit 49c41e6a5f
8 changed files with 166 additions and 6 deletions

View File

@@ -1,5 +1,9 @@
import argparse
from collections import OrderedDict
import sys
import os
ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(ROOT_DIR)
import torch

View File

@@ -0,0 +1,42 @@
import torch
from safetensors.torch import save_file, load_file
from collections import OrderedDict
meta = OrderedDict()
meta["format"] ="pt"
attn_dict = load_file("/mnt/Train/out/ip_adapter/sd15_bigG/sd15_bigG_000266000.safetensors")
state_dict = load_file("/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors")
attn_list = []
for key, value in state_dict.items():
if "attn1" in key:
attn_list.append(key)
attn_names = ['down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor']
adapter_names = []
for i in range(100):
if f'te_adapter.adapter_modules.{i}.to_k_adapter.weight' in attn_dict:
adapter_names.append(f"te_adapter.adapter_modules.{i}.adapter")
for i in range(len(adapter_names)):
adapter_name = adapter_names[i]
attn_name = attn_names[i]
adapter_k_name = adapter_name[:-8] + '.to_k_adapter.weight'
adapter_v_name = adapter_name[:-8] + '.to_v_adapter.weight'
state_k_name = attn_name.replace(".processor", ".to_k.weight")
state_v_name = attn_name.replace(".processor", ".to_v.weight")
if adapter_k_name in attn_dict:
state_dict[state_k_name] = attn_dict[adapter_k_name]
state_dict[state_v_name] = attn_dict[adapter_v_name]
else:
print("adapter_k_name", adapter_k_name)
print("state_k_name", state_k_name)
for key, value in state_dict.items():
state_dict[key] = value.cpu().to(torch.float16)
save_file(state_dict, "/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors", metadata=meta)
print("Done")

View File

@@ -0,0 +1,65 @@
import argparse
from PIL import Image
from PIL.ImageOps import exif_transpose
from tqdm import tqdm
import os
parser = argparse.ArgumentParser(description='Process some images.')
parser.add_argument("input_folder", type=str, help="Path to folder containing images")
args = parser.parse_args()
img_types = ['.jpg', '.jpeg', '.png', '.webp']
# find all images in the input folder
images = []
for root, _, files in os.walk(args.input_folder):
for file in files:
if file.lower().endswith(tuple(img_types)):
images.append(os.path.join(root, file))
print(f"Found {len(images)} images")
num_skipped = 0
num_repaired = 0
num_deleted = 0
pbar = tqdm(total=len(images), desc=f"Repaired {num_repaired} images", unit="image")
for img_path in images:
filename = os.path.basename(img_path)
filename_no_ext, file_extension = os.path.splitext(filename)
# if it is jpg, ignore
if file_extension.lower() == '.jpg':
num_skipped += 1
pbar.update(1)
continue
try:
img = Image.open(img_path)
except Exception as e:
print(f"Error opening {img_path}: {e}")
# delete it
os.remove(img_path)
num_deleted += 1
pbar.update(1)
pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}")
continue
try:
img = exif_transpose(img)
except Exception as e:
print(f"Error rotating {img_path}: {e}")
new_path = os.path.join(os.path.dirname(img_path), filename_no_ext + '.jpg')
img = img.convert("RGB")
img.save(new_path, quality=95)
# remove the old file
os.remove(img_path)
num_repaired += 1
pbar.update(1)
# update pbar
pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}")
print("Done")