ostris/ai-toolkit (mirror of https://github.com/ostris/ai-toolkit.git)

Commit: Bug fixes, negative prompting during training, hardened error catching
@@ -1,11 +1,12 @@
 from collections import OrderedDict
-from typing import Union
+from typing import Union, Literal, List
 from diffusers import T2IAdapter

 from toolkit import train_tools
-from toolkit.basic import value_map, adain
+from toolkit.basic import value_map, adain, get_mean_std
 from toolkit.config_modules import GuidanceConfig
-from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
+from toolkit.image_utils import show_tensors, show_latents
 from toolkit.ip_adapter import IPAdapter
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
@@ -309,7 +310,6 @@ class SDTrainer(BaseSDTrainProcess):
             pass

     def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
-
         self.timer.start('preprocess_batch')
         batch = self.preprocess_batch(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
@@ -322,6 +322,7 @@ class SDTrainer(BaseSDTrainProcess):

         match_adapter_assist = False


+        # check if we are matching the adapter assistant
         if self.assistant_adapter:
             if self.train_config.match_adapter_chance == 1.0:
@@ -334,6 +335,12 @@ class SDTrainer(BaseSDTrainProcess):
         self.timer.stop('preprocess_batch')

+        with torch.no_grad():
+            loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
+            for idx, file_item in enumerate(batch.file_items):
+                if file_item.is_reg:
+                    loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
+

         adapter_images = None
         sigmas = None
         if has_adapter_img and (self.adapter or self.assistant_adapter):
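Note on the hunk above: loss_multiplier gives every sample in the batch its own weight, which is how reg_weight down-weights regularization images; the backward hunk further down currently collapses it with loss_multiplier.mean(), as its own todo admits. A standalone sketch of the fully per-sample version (batch contents and weights here are made up):

    import torch

    batch_size = 4
    is_reg = [False, True, False, True]   # stand-in for the file_item.is_reg flags
    reg_weight = 0.5                      # stand-in for train_config.reg_weight

    pred = torch.randn(batch_size, 4, 8, 8)
    target = torch.randn(batch_size, 4, 8, 8)

    # one multiplier per sample, broadcastable over (C, H, W)
    loss_multiplier = torch.ones(batch_size, 1, 1, 1)
    for idx, reg in enumerate(is_reg):
        if reg:
            loss_multiplier[idx] *= reg_weight

    # weight each sample's element-wise MSE before reducing
    loss = (torch.nn.functional.mse_loss(pred, target, reduction='none') * loss_multiplier).mean()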
@@ -471,6 +478,12 @@ class SDTrainer(BaseSDTrainProcess):
                 mask_multiplier_list,
                 prompt_2_list
         ):
+        if self.train_config.negative_prompt is not None:
+            # add negative prompt
+            conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
+                                                         range(len(conditioned_prompts))]
+            if prompt_2 is not None:
+                prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]

         with network:
             with self.timer('encode_prompt'):
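Note: appending one copy of train_config.negative_prompt per conditioned prompt doubles the prompt list, so positive and negative embeddings come out of a single encoder pass and can be split afterwards. A rough sketch of that batching trick (the encode function is a dummy, not the toolkit's API):

    # dummy encoder standing in for the real text encoder
    def encode(prompts):
        return [f"emb({p})" for p in prompts]

    conditioned_prompts = ["a photo of a dog", "a photo of a cat"]
    negative_prompt = "blurry, low quality"

    # mirror the diff: one negative appended per positive
    prompts = conditioned_prompts + [negative_prompt for _ in range(len(conditioned_prompts))]
    embeddings = encode(prompts)

    # first half conditions the model, second half drives the negative guidance
    n = len(conditioned_prompts)
    positive, negative = embeddings[:n], embeddings[n:]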
@@ -585,6 +598,8 @@ class SDTrainer(BaseSDTrainProcess):
                 raise ValueError("loss is nan")

         with self.timer('backward'):
+            # todo we have multiplier separated. works for now as res are not in same batch, but need to change
+            loss = loss * loss_multiplier.mean()
             # IMPORTANT if gradient checkpointing do not leave with network when doing backward
             # it will destroy the gradients. This is because the network is a context manager
             # and will change the multipliers back to 0.0 when exiting. They will be
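The IMPORTANT comment is clipped by the hunk boundary, but the reasoning it starts is worth spelling out: with gradient checkpointing, parts of the forward pass are re-run during backward(), and since the network context manager zeroes the LoRA multipliers on exit, leaving the with-block before calling backward() would make the recomputed forward differ from the original one. A runnable schematic of the safe ordering (the context manager here is a toy stand-in):

    import contextlib
    import torch

    @contextlib.contextmanager
    def network():
        # toy stand-in: the real one enables multipliers on enter, zeroes them on exit
        yield

    w = torch.ones(3, requires_grad=True)
    with network():
        loss = (w * 2).sum()
        loss.backward()   # stays inside, so any checkpoint recomputation sees live multipliers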
@@ -1259,6 +1259,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         ###################################################################

         start_step_num = self.step_num
+        did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
             self.step_num = step
             # default to true so various things can turn it off
@@ -1332,6 +1333,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.timer.start('train_loop')
                 loss_dict = self.hook_train_loop(batch)
                 self.timer.stop('train_loop')
+                if not did_first_flush:
+                    flush()
+                    did_first_flush = True
                 # flush()
                 # setup the networks to gradient checkpointing and everything works

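Note: flushing once after the first step releases the transient peak from autograd-graph and optimizer-state setup without paying the (slow) cache-flush cost every iteration. A minimal sketch, assuming the repo's flush() helper is roughly gc plus empty_cache:

    import gc
    import torch

    def flush():
        # assumed implementation; the real helper may differ in detail
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    did_first_flush = False
    for step in range(100):
        # ... one training step ...
        if not did_first_flush:
            flush()               # only after the first step, not every iteration
            did_first_flush = True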
scripts/make_diffusers_model.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    'input_path',
+    type=str,
+    help='Path to original sdxl model'
+)
+parser.add_argument(
+    'output_path',
+    type=str,
+    help='output path'
+)
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+
+args = parser.parse_args()
+device = torch.device('cpu')
+dtype = torch.float32
+
+print(f"Loading model from {args.input_path}")
+
+
+diffusers_model_config = ModelConfig(
+    name_or_path=args.input_path,
+    is_xl=args.sdxl,
+    is_v2=args.sd2,
+    is_ssd=args.ssd,
+    dtype=dtype,
+)
+diffusers_sd = StableDiffusion(
+    model_config=diffusers_model_config,
+    device=device,
+    dtype=dtype,
+)
+diffusers_sd.load_model()
+
+
+print(f"Loaded model from {args.input_path}")
+
+diffusers_sd.pipeline.fuse_lora()
+
+meta = OrderedDict()
+
+diffusers_sd.save(args.output_path, meta=meta)
+
+
+print(f"Saved to {args.output_path}")
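A likely invocation of the new script, converting a single-file SDXL checkpoint into a diffusers folder (paths are illustrative):

    python scripts/make_diffusers_model.py /models/sdxl_base.safetensors /models/sdxl_base_diffusers --sdxl

Two things worth flagging from the code as shown: --refiner is parsed but never passed into ModelConfig, and fuse_lora() appears to be a no-op unless LoRA weights were loaded into the pipeline beforehand.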
@@ -193,6 +193,9 @@ class TrainConfig:
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
+        self.negative_prompt = kwargs.get('negative_prompt', None)
+        # multiplier applied to loss on regularization images
+        self.reg_weight = kwargs.get('reg_weight', 1.0)

         # dropout that happens before encoding. It functions independently per text encoder
         self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
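For reference, the two new options as they would sit in a config, assuming TrainConfig takes plain keyword arguments as the kwargs.get calls suggest (values below are examples only):

    from toolkit.config_modules import TrainConfig

    train_config = TrainConfig(
        negative_prompt='blurry, low quality, watermark',  # paired with every prompt when set
        reg_weight=0.5,  # down-weights loss on regularization images
    )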
@@ -77,6 +77,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:


 def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
-    with safe_open(file_path, framework="pt") as f:
-        metadata = f.metadata()
-    return parse_metadata_from_safetensors(metadata)
+    try:
+        with safe_open(file_path, framework="pt") as f:
+            metadata = f.metadata()
+        return parse_metadata_from_safetensors(metadata)
+    except Exception as e:
+        print(f"Error loading metadata from {file_path}: {e}")
+        return OrderedDict()
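This is the hardened error catching from the commit message: a missing, corrupt, or metadata-less file now degrades to an empty OrderedDict instead of raising into the caller. A quick standalone check of the failure path (file path made up; the parse step is omitted in this sketch):

    from collections import OrderedDict
    from safetensors import safe_open

    def load_metadata(file_path):
        try:
            with safe_open(file_path, framework="pt") as f:
                metadata = f.metadata()
            return metadata or OrderedDict()
        except Exception as e:
            print(f"Error loading metadata from {file_path}: {e}")
            return OrderedDict()

    print(load_metadata("/tmp/does_not_exist.safetensors"))   # prints the error, then OrderedDict()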
@@ -254,6 +254,7 @@ class ToolkitModuleMixin:
         multiplier_batch_size = multiplier.size(0)
         if lora_output_batch_size != multiplier_batch_size:
             num_interleaves = lora_output_batch_size // multiplier_batch_size
+            # todo check if this is correct, do we just concat when doing cfg?
             multiplier = multiplier.repeat_interleave(num_interleaves)

         x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
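On the repeat_interleave fix: with classifier-free guidance the module sees a batch that is the unconditional and conditional halves concatenated, so lora_output can be an integer multiple of the multiplier's batch size. A toy demonstration of the shape repair:

    import torch

    multiplier = torch.tensor([0.5, 1.0])   # one multiplier per original sample
    lora_output = torch.randn(4, 3)         # batch doubled, e.g. by a CFG concat

    num_interleaves = lora_output.size(0) // multiplier.size(0)
    expanded = multiplier.repeat_interleave(num_interleaves)
    # expanded is tensor([0.5, 0.5, 1.0, 1.0])

The todo's worry is real: if the doubled batch is laid out as cat([uncond, cond]) rather than as interleaved pairs, multiplier.repeat(2), which tiles to [0.5, 1.0, 0.5, 1.0], would be the alignment that matches samples, not repeat_interleave.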
@@ -470,11 +471,11 @@
         self.torch_multiplier = tensor_multiplier.clone().detach()

     @property
-    def multiplier(self) -> Union[float, List[float]]:
+    def multiplier(self) -> Union[float, List[float], List[List[float]]]:
         return self._multiplier

     @multiplier.setter
-    def multiplier(self, value: Union[float, List[float]]):
+    def multiplier(self, value: Union[float, List[float], List[List[float]]]):
         # it takes time to update all the multipliers, so we only do it if the value has changed
         if self._multiplier == value:
             return
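The widened annotation lets callers hand the network a nested list of multipliers, in addition to a scalar or a flat per-item list; the early-exit comparison in the setter still works because Python compares lists, including nested ones, by value. Example shapes (the semantics of the inner lists are the network's; these just show the accepted forms):

    m_scalar = 1.0                        # one multiplier for everything
    m_per_item = [1.0, 0.0]               # one multiplier per batch item
    m_nested = [[1.0, 0.0], [0.5, 0.5]]   # a list of multipliers per batch item

    assert m_nested == [[1.0, 0.0], [0.5, 0.5]]   # value equality, so the setter can skip work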