mirror of https://github.com/ostris/ai-toolkit.git

Bug fixes. allow for random negative prompts
@@ -1,3 +1,4 @@
+import os
 import random
 from collections import OrderedDict
 from typing import Union, Literal, List, Optional
@@ -51,6 +52,8 @@ class SDTrainer(BaseSDTrainProcess):
         self.taesd: Optional[AutoencoderTiny] = None

         self._clip_image_embeds_unconditional: Union[List[str], None] = None
+        self.negative_prompt_pool: Union[List[str], None] = None
+        self.batch_negative_prompt: Union[List[str], None] = None

     def before_model_load(self):
         pass
@@ -108,6 +111,16 @@ class SDTrainer(BaseSDTrainProcess):

         self._clip_image_embeds_unconditional = unconditional_clip_image_embeds

+        if self.train_config.negative_prompt is not None:
+            if os.path.exists(self.train_config.negative_prompt):
+                with open(self.train_config.negative_prompt, 'r') as f:
+                    self.negative_prompt_pool = f.readlines()
+                # remove empty
+                self.negative_prompt_pool = [x.strip() for x in self.negative_prompt_pool if x.strip() != ""]
+            else:
+                # single prompt
+                self.negative_prompt_pool = [self.train_config.negative_prompt]
+
     def process_output_for_turbo(self, pred, noisy_latents, timesteps, noise, batch):
         # to process turbo learning, we make one big step from our current timestep to the end
         # we then denoise the prediction on that remaining step and target our loss to our target latents
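
The hunk above reads the new negative_prompt config value two ways: if it is a path to an existing file, every non-empty line of that file becomes a pool entry; otherwise the string itself becomes a one-entry pool. A standalone sketch of the same logic (the function name is illustrative, not toolkit API):

import os

def load_negative_prompt_pool(negative_prompt: str) -> list:
    # a path means "one negative prompt per line"
    if os.path.exists(negative_prompt):
        with open(negative_prompt, 'r') as f:
            lines = f.readlines()
        # strip whitespace and drop empty lines, as the trainer does
        return [x.strip() for x in lines if x.strip() != ""]
    # otherwise the config string is itself the single-prompt pool
    return [negative_prompt]
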
@@ -781,6 +794,18 @@ class SDTrainer(BaseSDTrainProcess):
         batch = self.preprocess_batch(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
         noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
+        if self.train_config.do_cfg or self.train_config.do_random_cfg:
+            # pick random negative prompts
+            if self.negative_prompt_pool is not None:
+                negative_prompts = []
+                for i in range(noisy_latents.shape[0]):
+                    num_neg = random.randint(1, self.train_config.max_negative_prompts)
+                    this_neg_prompts = [random.choice(self.negative_prompt_pool) for _ in range(num_neg)]
+                    this_neg_prompt = ', '.join(this_neg_prompts)
+                    negative_prompts.append(this_neg_prompt)
+                self.batch_negative_prompt = negative_prompts
+            else:
+                self.batch_negative_prompt = ['' for _ in range(batch.latents.shape[0])]

         if self.adapter and isinstance(self.adapter, CustomAdapter):
             # condition the prompt
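
For each sample in the batch, the added block draws between 1 and max_negative_prompts entries from the pool (with replacement, so the same entry can appear twice) and joins them into one comma-separated negative prompt. The same sampling in isolation (function name illustrative):

import random

def sample_batch_negatives(pool, batch_size, max_negative_prompts):
    batch_negatives = []
    for _ in range(batch_size):
        # each sample gets 1..max_negative_prompts pool entries
        num_neg = random.randint(1, max_negative_prompts)
        picks = [random.choice(pool) for _ in range(num_neg)]
        batch_negatives.append(', '.join(picks))
    return batch_negatives

# sample_batch_negatives(["blurry", "low quality", "watermark"], 4, 2)
# might return e.g. ['watermark', 'blurry, blurry', 'low quality', 'blurry, watermark']
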
@@ -1030,7 +1055,8 @@ class SDTrainer(BaseSDTrainProcess):
                 if self.train_config.do_cfg:
                     # todo only do one and repeat it
                     unconditional_embeds = self.sd.encode_prompt(
-                        ["" for _ in range(noisy_latents.shape[0])],
+                        self.batch_negative_prompt,
+                        self.batch_negative_prompt,
                         dropout_prob=self.train_config.prompt_dropout_prob,
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
@@ -1050,9 +1076,8 @@ class SDTrainer(BaseSDTrainProcess):
                         self.device_torch,
                         dtype=dtype)
                 if self.train_config.do_cfg:
-                    # todo only do one and repeat it
                     unconditional_embeds = self.sd.encode_prompt(
-                        ["" for _ in range(noisy_latents.shape[0])],
+                        self.batch_negative_prompt,
                         dropout_prob=self.train_config.prompt_dropout_prob,
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
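
Both call sites now encode the sampled negatives as the "unconditional" embeddings for classifier-free guidance, so the unconditional branch is trained against the negative concepts rather than the empty prompt. At sampling time, CFG then steers predictions away from those concepts; a minimal sketch of that combination (names illustrative, not the trainer's code):

def cfg_combine(pred_cond, pred_uncond, guidance_scale):
    # guidance pushes the prediction away from the unconditional branch,
    # which is now conditioned on the sampled negative prompt
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)
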
@@ -1,5 +1,9 @@
 import argparse
 from collections import OrderedDict
+import sys
+import os
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+sys.path.append(ROOT_DIR)

 import torch
scripts/patch_te_adapter.py (new file, 42 lines)
@@ -0,0 +1,42 @@
+import torch
+from safetensors.torch import save_file, load_file
+from collections import OrderedDict
+meta = OrderedDict()
+meta["format"] = "pt"
+
+attn_dict = load_file("/mnt/Train/out/ip_adapter/sd15_bigG/sd15_bigG_000266000.safetensors")
+state_dict = load_file("/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors")
+
+attn_list = []
+for key, value in state_dict.items():
+    if "attn1" in key:
+        attn_list.append(key)
+
+attn_names = ['down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.0.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.1.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.2.attentions.2.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.0.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.1.transformer_blocks.0.attn2.processor', 'up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor', 'mid_block.attentions.0.transformer_blocks.0.attn2.processor']
+
+adapter_names = []
+for i in range(100):
+    if f'te_adapter.adapter_modules.{i}.to_k_adapter.weight' in attn_dict:
+        adapter_names.append(f"te_adapter.adapter_modules.{i}.adapter")
+
+
+for i in range(len(adapter_names)):
+    adapter_name = adapter_names[i]
+    attn_name = attn_names[i]
+    adapter_k_name = adapter_name[:-8] + '.to_k_adapter.weight'
+    adapter_v_name = adapter_name[:-8] + '.to_v_adapter.weight'
+    state_k_name = attn_name.replace(".processor", ".to_k.weight")
+    state_v_name = attn_name.replace(".processor", ".to_v.weight")
+    if adapter_k_name in attn_dict:
+        state_dict[state_k_name] = attn_dict[adapter_k_name]
+        state_dict[state_v_name] = attn_dict[adapter_v_name]
+    else:
+        print("adapter_k_name", adapter_k_name)
+        print("state_k_name", state_k_name)
+
+for key, value in state_dict.items():
+    state_dict[key] = value.cpu().to(torch.float16)
+
+save_file(state_dict, "/home/jaret/Dev/models/hf/OstrisDiffusionV1/unet/diffusion_pytorch_model.safetensors", metadata=meta)
+
+print("Done")
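
The script pairs adapter_names and attn_names purely by index, and the [:-8] slice strips the trailing ".adapter" (8 characters) from each module name so the weight-key suffixes can be appended. For example:

adapter_name = "te_adapter.adapter_modules.0.adapter"
prefix = adapter_name[:-8]  # "te_adapter.adapter_modules.0"
print(prefix + '.to_k_adapter.weight')
# te_adapter.adapter_modules.0.to_k_adapter.weight

Note the hard-coded local input and output paths: this is a one-off patching utility, not a general tool.
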
scripts/repair_dataset_folder.py (new file, 65 lines)
@@ -0,0 +1,65 @@
+import argparse
+from PIL import Image
+from PIL.ImageOps import exif_transpose
+from tqdm import tqdm
+import os
+
+parser = argparse.ArgumentParser(description='Process some images.')
+parser.add_argument("input_folder", type=str, help="Path to folder containing images")
+
+args = parser.parse_args()
+
+img_types = ['.jpg', '.jpeg', '.png', '.webp']
+
+# find all images in the input folder
+images = []
+for root, _, files in os.walk(args.input_folder):
+    for file in files:
+        if file.lower().endswith(tuple(img_types)):
+            images.append(os.path.join(root, file))
+print(f"Found {len(images)} images")
+
+num_skipped = 0
+num_repaired = 0
+num_deleted = 0
+
+pbar = tqdm(total=len(images), desc=f"Repaired {num_repaired} images", unit="image")
+for img_path in images:
+    filename = os.path.basename(img_path)
+    filename_no_ext, file_extension = os.path.splitext(filename)
+    # if it is jpg, ignore
+    if file_extension.lower() == '.jpg':
+        num_skipped += 1
+        pbar.update(1)
+
+        continue
+
+    try:
+        img = Image.open(img_path)
+    except Exception as e:
+        print(f"Error opening {img_path}: {e}")
+        # delete it
+        os.remove(img_path)
+        num_deleted += 1
+        pbar.update(1)
+        pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}")
+        continue
+
+
+    try:
+        img = exif_transpose(img)
+    except Exception as e:
+        print(f"Error rotating {img_path}: {e}")
+
+    new_path = os.path.join(os.path.dirname(img_path), filename_no_ext + '.jpg')
+
+    img = img.convert("RGB")
+    img.save(new_path, quality=95)
+    # remove the old file
+    os.remove(img_path)
+    num_repaired += 1
+    pbar.update(1)
+    # update pbar
+    pbar.set_description(f"Repaired {num_repaired} images, Skipped {num_skipped}, Deleted {num_deleted}")
+
+print("Done")
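
The script is destructive by design: images PIL cannot open are deleted, and every non-JPEG is re-encoded as a quality-95 JPEG (after EXIF rotation) with the original file removed. It takes a single positional argument, e.g. `python scripts/repair_dataset_folder.py /path/to/dataset`, so running it on a copy of the dataset first is prudent.
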
@@ -240,6 +240,7 @@ class TrainConfig:
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
         self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
         self.negative_prompt = kwargs.get('negative_prompt', None)
+        self.max_negative_prompts = kwargs.get('max_negative_prompts', 1)
         # multiplier applied to loss on regularization images
         self.reg_weight = kwargs.get('reg_weight', 1.0)
         self.num_train_timesteps = kwargs.get('num_train_timesteps', 1000)
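
With the new key, a config can cap how many pool entries are concatenated per sample. A hypothetical instantiation showing both options together (values illustrative; the import path is assumed from this repo's toolkit/config_modules layout, and these kwargs normally arrive from the YAML config):

from toolkit.config_modules import TrainConfig

train_config = TrainConfig(
    negative_prompt="/path/to/negative_prompts.txt",  # a file (one prompt per line) or a literal prompt string
    max_negative_prompts=3,  # join up to 3 random pool entries per sample
)
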
@@ -322,7 +322,7 @@ class IPAdapter(torch.nn.Module):
         elif adapter_config.type == 'ip+':
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
-            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
+            embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch.startswith('convnext') else \
                 self.image_encoder.config.hidden_sizes[-1]

             image_encoder_state_dict = self.image_encoder.state_dict()
@@ -340,6 +340,10 @@ class IPAdapter(torch.nn.Module):
             dim = 4096
             output_dim = 4096

+        if self.config.image_encoder_arch.startswith('convnext'):
+            in_tokens = 16 * 16
+            embedding_dim = self.image_encoder.config.hidden_sizes[-1]
+
         # ip-adapter-plus
         image_proj_model = Resampler(
             dim=dim,
@@ -406,6 +410,8 @@ class IPAdapter(torch.nn.Module):
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())

+        attn_processor_names = []
+
         for name in attn_processor_keys:
             cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
                 sd.unet.config['cross_attention_dim']
@@ -446,6 +452,9 @@ class IPAdapter(torch.nn.Module):
             }

             attn_procs[name].load_state_dict(weights)
+            attn_processor_names.append(name)
+        print(f"Attn Processors")
+        print(attn_processor_names)
         if self.sd_ref().is_pixart:
             # we have to set them ourselves
             transformer: Transformer2DModel = sd.unet
@@ -690,6 +699,12 @@ class IPAdapter(torch.nn.Module):
         else:
             clip_image_embeds = clip_output.image_embeds

+        if self.config.image_encoder_arch.startswith('convnext'):
+            # flatten the width height layers to make the token space
+            clip_image_embeds = clip_image_embeds.view(clip_image_embeds.size(0), clip_image_embeds.size(1), -1)
+            # rearrange to (batch, tokens, size)
+            clip_image_embeds = clip_image_embeds.permute(0, 2, 1)
+
         if self.config.quad_image:
             # get the outputs of the quad
             chunks = clip_image_embeds.chunk(quad_count, dim=0)
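
The reshape turns the convnext spatial feature map into the (batch, tokens, size) layout the Resampler expects, which is also why the earlier hunk sets in_tokens = 16 * 16. A shape walk-through with illustrative dimensions:

import torch

feats = torch.randn(2, 1024, 16, 16)  # (batch, channels, h, w) from the encoder
tokens = feats.view(feats.size(0), feats.size(1), -1)  # (2, 1024, 256)
tokens = tokens.permute(0, 2, 1)  # (2, 256, 1024): 256 tokens of size 1024
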
@@ -171,11 +171,19 @@ class TEAdapter(torch.nn.Module):
         self.te_ref: weakref.ref = weakref.ref(te)
         self.tokenizer_ref: weakref.ref = weakref.ref(tokenizer)

-        self.token_size = self.te_ref().config.d_model
+        if self.adapter_ref().config.text_encoder_arch == "t5":
+            self.token_size = self.te_ref().config.d_model
+        else:
+            self.token_size = self.te_ref().config.hidden_size

         # init adapter modules
         attn_procs = {}
         unet_sd = sd.unet.state_dict()
+        attn_dict_map = {
+
+        }
+        module_idx = 0
+        attn_processors_list = list(sd.unet.attn_processors.keys())
         for name in sd.unet.attn_processors.keys():
             cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
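
The branch exists because T5-style configs name the embedding width d_model while CLIP text encoder configs name it hidden_size, so reading the wrong attribute raises AttributeError. A quick check with default transformers configs (the widths shown are the library defaults, not those of any trained model here):

from transformers import CLIPTextConfig, T5Config

print(T5Config().d_model)            # 512
print(CLIPTextConfig().hidden_size)  # 512
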
@@ -287,7 +287,7 @@ class StableDiffusion:
                 load_safety_checker=False,
                 requires_safety_checker=False,
                 torch_dtype=self.torch_dtype,
-                safety_checker=False,
+                safety_checker=None,
                 **load_args
             ).to(self.device_torch)
             flush()
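
diffusers expects safety_checker to be a model instance or None; None is the supported way to disable it, whereas False is not a valid pipeline component and can trip the pipeline's module handling, hence the one-word fix. A minimal sketch of the corrected call shape (model id illustrative):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    safety_checker=None,  # disable the checker the supported way
    requires_safety_checker=False,
)
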