mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Bug fixes, negative prompting during training, hardened catching
@@ -1,11 +1,12 @@
 from collections import OrderedDict
-from typing import Union
+from typing import Union, Literal, List
 from diffusers import T2IAdapter

 from toolkit import train_tools
-from toolkit.basic import value_map, adain
+from toolkit.basic import value_map, adain, get_mean_std
 from toolkit.config_modules import GuidanceConfig
-from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
+from toolkit.image_utils import show_tensors, show_latents
 from toolkit.ip_adapter import IPAdapter
 from toolkit.prompt_utils import PromptEmbeds
 from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
@@ -309,7 +310,6 @@ class SDTrainer(BaseSDTrainProcess):
         pass

     def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
-
         self.timer.start('preprocess_batch')
         batch = self.preprocess_batch(batch)
         dtype = get_torch_dtype(self.train_config.dtype)
@@ -322,6 +322,7 @@ class SDTrainer(BaseSDTrainProcess):

         match_adapter_assist = False

+
         # check if we are matching the adapter assistant
         if self.assistant_adapter:
             if self.train_config.match_adapter_chance == 1.0:
@@ -334,6 +335,12 @@ class SDTrainer(BaseSDTrainProcess):
         self.timer.stop('preprocess_batch')

         with torch.no_grad():
+            loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
+            for idx, file_item in enumerate(batch.file_items):
+                if file_item.is_reg:
+                    loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
+
+
             adapter_images = None
             sigmas = None
             if has_adapter_img and (self.adapter or self.assistant_adapter):
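Note: the hunk above builds a per-item multiplier so regularization images can contribute less to the loss (train_config.reg_weight). Below is a minimal, self-contained sketch of that weighting idea using stand-in data rather than the trainer's real DTOs; also note that in the @@ -585 hunk further down, the commit currently applies loss_multiplier.mean() as a single scalar (its own todo acknowledges this), whereas this sketch shows the broadcast per-sample form.

import torch

# Stand-ins for train_config.reg_weight and batch.file_items[i].is_reg
reg_weight = 0.5
is_reg = [False, True, False, True]
batch_size = len(is_reg)

loss_multiplier = torch.ones((batch_size, 1, 1, 1))
for idx, reg in enumerate(is_reg):
    if reg:
        loss_multiplier[idx] = loss_multiplier[idx] * reg_weight

# Element-wise loss on fake latents, weighted per sample before reduction
pred = torch.randn(batch_size, 4, 8, 8)
target = torch.randn(batch_size, 4, 8, 8)
per_element = torch.nn.functional.mse_loss(pred, target, reduction="none")
loss = (per_element * loss_multiplier).mean()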
@@ -471,6 +478,12 @@ class SDTrainer(BaseSDTrainProcess):
                 mask_multiplier_list,
                 prompt_2_list
         ):
+            if self.train_config.negative_prompt is not None:
+                # add negative prompt
+                conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in
+                                                             range(len(conditioned_prompts))]
+                if prompt_2 is not None:
+                    prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]

             with network:
                 with self.timer('encode_prompt'):
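Note: the negative-prompt branch above appends one copy of train_config.negative_prompt per positive prompt before encoding, so positives and negatives are embedded in one doubled batch. A toy illustration with placeholder prompt strings:

conditioned_prompts = ["a photo of sks dog", "sks dog running"]  # placeholder prompts
negative_prompt = "blurry, low quality"  # would come from train_config.negative_prompt

conditioned_prompts = conditioned_prompts + [negative_prompt for x in range(len(conditioned_prompts))]
print(conditioned_prompts)
# ['a photo of sks dog', 'sks dog running', 'blurry, low quality', 'blurry, low quality']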
@@ -585,6 +598,8 @@ class SDTrainer(BaseSDTrainProcess):
                 raise ValueError("loss is nan")

             with self.timer('backward'):
+                # todo we have multiplier seperated. works for now as res are not in same batch, but need to change
+                loss = loss * loss_multiplier.mean()
                 # IMPORTANT if gradient checkpointing do not leave with network when doing backward
                 # it will destroy the gradients. This is because the network is a context manager
                 # and will change the multipliers back to 0.0 when exiting. They will be
@@ -1259,6 +1259,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         ###################################################################

         start_step_num = self.step_num
+        did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
             self.step_num = step
             # default to true so various things can turn it off
@@ -1332,6 +1333,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 self.timer.start('train_loop')
                 loss_dict = self.hook_train_loop(batch)
                 self.timer.stop('train_loop')
+                if not did_first_flush:
+                    flush()
+                    did_first_flush = True
                 # flush()
                 # setup the networks to gradient checkpointing and everything works

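Note: did_first_flush makes the memory flush run only once, right after the first training step, instead of after every step. A generic sketch of that pattern, assuming flush() releases cached GPU memory (the toolkit's actual flush helper is not shown in this diff, so the implementation below is a stand-in):

import gc
import torch

def flush():
    # stand-in for the toolkit's flush(); assumed to reclaim cached GPU memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

did_first_flush = False
for step in range(10):              # placeholder step count
    loss_dict = {"loss": 0.0}       # placeholder for hook_train_loop(batch)
    if not did_first_flush:
        flush()                     # the first step usually spikes allocations; flush once
        did_first_flush = True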
scripts/make_diffusers_model.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+import argparse
+from collections import OrderedDict
+
+import torch
+
+from toolkit.config_modules import ModelConfig
+from toolkit.stable_diffusion_model import StableDiffusion
+
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+    'input_path',
+    type=str,
+    help='Path to original sdxl model'
+)
+parser.add_argument(
+    'output_path',
+    type=str,
+    help='output path'
+)
+parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
+parser.add_argument('--refiner', action='store_true', help='is refiner model')
+parser.add_argument('--ssd', action='store_true', help='is ssd model')
+parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
+
+args = parser.parse_args()
+device = torch.device('cpu')
+dtype = torch.float32
+
+print(f"Loading model from {args.input_path}")
+
+
+diffusers_model_config = ModelConfig(
+    name_or_path=args.input_path,
+    is_xl=args.sdxl,
+    is_v2=args.sd2,
+    is_ssd=args.ssd,
+    dtype=dtype,
+)
+diffusers_sd = StableDiffusion(
+    model_config=diffusers_model_config,
+    device=device,
+    dtype=dtype,
+)
+diffusers_sd.load_model()
+
+
+print(f"Loaded model from {args.input_path}")
+
+diffusers_sd.pipeline.fuse_lora()
+
+meta = OrderedDict()
+
+diffusers_sd.save(args.output_path, meta=meta)
+
+
+print(f"Saved to {args.output_path}")
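Note: the new script takes an input checkpoint and an output path as positional arguments, with flags for the model family. A likely invocation (paths are placeholders; the script presumably needs to be run from the repo root so the toolkit imports resolve):

python scripts/make_diffusers_model.py /path/to/sdxl_model.safetensors /path/to/output_dir --sdxl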
@@ -193,6 +193,9 @@ class TrainConfig:
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
+        self.negative_prompt = kwargs.get('negative_prompt', None)
+        # multiplier applied to loos on regularization images
+        self.reg_weight = kwargs.get('reg_weight', 1.0)

         # dropout that happens before encoding. It functions independently per text encoder
         self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
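Note: the two new TrainConfig keys can be exercised directly. A minimal sketch, assuming TrainConfig lives in toolkit.config_modules (as the other config imports in this diff suggest) and accepts these values as keyword arguments, as its kwargs.get calls imply; the values themselves are illustrative:

from toolkit.config_modules import TrainConfig

train_config = TrainConfig(
    negative_prompt="blurry, low quality, watermark",  # enables the negative-prompt branch in hook_train_loop
    reg_weight=0.5,                                    # scales loss on regularization images
)
assert train_config.negative_prompt is not None
assert train_config.reg_weight == 0.5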
@@ -77,6 +77,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:


 def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
-    with safe_open(file_path, framework="pt") as f:
-        metadata = f.metadata()
-        return parse_metadata_from_safetensors(metadata)
+    try:
+        with safe_open(file_path, framework="pt") as f:
+            metadata = f.metadata()
+            return parse_metadata_from_safetensors(metadata)
+    except Exception as e:
+        print(f"Error loading metadata from {file_path}: {e}")
+        return OrderedDict()
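Note: with the try/except in place, a missing or unreadable file now logs the error and returns an empty OrderedDict instead of raising. A small sketch (the toolkit.metadata module path is an assumption; the file path is a deliberately non-existent placeholder):

from toolkit.metadata import load_metadata_from_safetensors  # module path assumed

meta = load_metadata_from_safetensors("/tmp/does_not_exist.safetensors")
print(dict(meta))  # {} -- the error is printed and swallowed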
@@ -254,6 +254,7 @@ class ToolkitModuleMixin:
             multiplier_batch_size = multiplier.size(0)
             if lora_output_batch_size != multiplier_batch_size:
                 num_interleaves = lora_output_batch_size // multiplier_batch_size
+                # todo check if this is correct, do we just concat when doing cfg?
                 multiplier = multiplier.repeat_interleave(num_interleaves)

         x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
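Note: the repeat_interleave above expands per-sample LoRA multipliers to match a batch whose size was multiplied (for example doubled for classifier-free guidance). A toy example in plain torch:

import torch

multiplier = torch.tensor([1.0, 0.5])   # one multiplier per original sample
lora_output_batch_size = 4              # e.g. batch doubled for CFG
num_interleaves = lora_output_batch_size // multiplier.size(0)
print(multiplier.repeat_interleave(num_interleaves))  # tensor([1.0000, 1.0000, 0.5000, 0.5000])
# If the doubled batch is a concatenation [uncond..., cond...], tiling with
# multiplier.repeat(num_interleaves) would line up instead, which is what the
# todo comment in the diff is questioning.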
@@ -470,11 +471,11 @@ class ToolkitNetworkMixin:
         self.torch_multiplier = tensor_multiplier.clone().detach()

     @property
-    def multiplier(self) -> Union[float, List[float]]:
+    def multiplier(self) -> Union[float, List[float], List[List[float]]]:
         return self._multiplier

     @multiplier.setter
-    def multiplier(self, value: Union[float, List[float]]):
+    def multiplier(self, value: Union[float, List[float], List[List[float]]]):
         # it takes time to update all the multipliers, so we only do it if the value has changed
         if self._multiplier == value:
             return