Bug fixes, negative prompting during training, hardened exception catching

Jaret Burkett
2023-11-24 07:25:11 -07:00
parent fbec68681d
commit d7e55b6ad4
6 changed files with 93 additions and 9 deletions

View File

@@ -1,11 +1,12 @@
from collections import OrderedDict
-from typing import Union
+from typing import Union, Literal, List
from diffusers import T2IAdapter
from toolkit import train_tools
-from toolkit.basic import value_map, adain
+from toolkit.basic import value_map, adain, get_mean_std
from toolkit.config_modules import GuidanceConfig
-from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
+from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO
+from toolkit.image_utils import show_tensors, show_latents
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import PromptEmbeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
@@ -309,7 +310,6 @@ class SDTrainer(BaseSDTrainProcess):
pass
def hook_train_loop(self, batch: 'DataLoaderBatchDTO'):
    self.timer.start('preprocess_batch')
    batch = self.preprocess_batch(batch)
    dtype = get_torch_dtype(self.train_config.dtype)
@@ -322,6 +322,7 @@ class SDTrainer(BaseSDTrainProcess):
match_adapter_assist = False
# check if we are matching the adapter assistant
if self.assistant_adapter:
    if self.train_config.match_adapter_chance == 1.0:
@@ -334,6 +335,12 @@ class SDTrainer(BaseSDTrainProcess):
self.timer.stop('preprocess_batch')
with torch.no_grad():
+    loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype)
+    for idx, file_item in enumerate(batch.file_items):
+        if file_item.is_reg:
+            loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight
    adapter_images = None
    sigmas = None
    if has_adapter_img and (self.adapter or self.assistant_adapter):
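The block added above builds one scalar weight per batch item so that regularization images can contribute less (or more) to the loss. A self-contained sketch of the same idea, with illustrative names (is_reg and reg_weight mirror the fields used in the code above):

    import torch

    def build_loss_multiplier(file_items, device, dtype, reg_weight=1.0):
        # One weight per sample, shaped (B, 1, 1, 1) so it broadcasts over latent-shaped losses.
        weights = torch.ones((len(file_items), 1, 1, 1), device=device, dtype=dtype)
        for idx, item in enumerate(file_items):
            if getattr(item, 'is_reg', False):
                weights[idx] = weights[idx] * reg_weight
        return weights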
@@ -471,6 +478,12 @@ class SDTrainer(BaseSDTrainProcess):
    mask_multiplier_list,
    prompt_2_list
):
+    if self.train_config.negative_prompt is not None:
+        # add negative prompt
+        conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in range(len(conditioned_prompts))]
+        if prompt_2 is not None:
+            prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))]
    with network:
        with self.timer('encode_prompt'):
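What the prompt handling above amounts to, shown in isolation (prompt strings are illustrative): the negative prompt is tiled onto the end of the prompt list, so positives and negatives can be encoded in a single pass and split back apart afterwards.

    conditioned_prompts = ["photo of sks dog", "sks dog on a beach"]
    negative_prompt = "blurry, low quality"  # illustrative value for train_config.negative_prompt

    conditioned_prompts = conditioned_prompts + [negative_prompt for _ in range(len(conditioned_prompts))]
    # -> ["photo of sks dog", "sks dog on a beach", "blurry, low quality", "blurry, low quality"]
    # First half: the original positives; second half: one copy of the negative per sample.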
@@ -585,6 +598,8 @@ class SDTrainer(BaseSDTrainProcess):
    raise ValueError("loss is nan")
with self.timer('backward'):
+    # todo we have multiplier separated. works for now as res are not in same batch, but need to change
+    loss = loss * loss_multiplier.mean()
    # IMPORTANT if gradient checkpointing do not leave with network when doing backward
    # it will destroy the gradients. This is because the network is a context manager
    # and will change the multipliers back to 0.0 when exiting. They will be
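The todo above points at a subtlety: loss_multiplier.mean() scales the whole batch loss by one number, which only matches true per-sample weighting when every item in the batch shares the same weight (e.g. reg and non-reg images never mix in one batch). A small numeric sketch of the distinction, with illustrative values:

    import torch

    per_sample_loss = torch.tensor([1.0, 1.0, 1.0, 1.0])   # illustrative per-sample losses
    loss_multiplier = torch.tensor([1.0, 1.0, 0.5, 0.5])   # last two items are reg images

    # Current behaviour: reduce first, then scale by the mean weight.
    a = per_sample_loss.mean() * loss_multiplier.mean()     # 1.0 * 0.75 = 0.75

    # Per-sample alternative the todo hints at: weight each item, then reduce.
    b = (per_sample_loss * loss_multiplier).mean()          # (1 + 1 + 0.5 + 0.5) / 4 = 0.75
    # The two agree here only because the losses are equal; with unequal losses they differ,
    # which is why mixing reg and non-reg items in one batch would need the second form.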

View File

@@ -1259,6 +1259,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
###################################################################
start_step_num = self.step_num
+did_first_flush = False
for step in range(start_step_num, self.train_config.steps):
    self.step_num = step
    # default to true so various things can turn it off
@@ -1332,6 +1333,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.timer.start('train_loop')
loss_dict = self.hook_train_loop(batch)
self.timer.stop('train_loop')
+if not did_first_flush:
+    flush()
+    did_first_flush = True
# flush()
# setup the networks to gradient checkpointing and everything works
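A minimal sketch of the first-flush pattern introduced above, assuming flush() is the toolkit's usual garbage-collect plus CUDA-cache-clear helper; the training-step function here is a hypothetical stand-in for hook_train_loop:

    import gc
    import torch

    def flush():
        # Assumed behaviour of the toolkit helper: collect garbage and release cached CUDA memory.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def train_one_step(step):
        return {"loss": 0.0}  # placeholder

    did_first_flush = False
    for step in range(10):
        loss_dict = train_one_step(step)
        if not did_first_flush:
            # Clear caches once, right after the first step, when peak allocations have been made;
            # doing it every step would add avoidable synchronization overhead.
            flush()
            did_first_flush = True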

View File

@@ -0,0 +1,57 @@
import argparse
from collections import OrderedDict
import torch
from toolkit.config_modules import ModelConfig
from toolkit.stable_diffusion_model import StableDiffusion
parser = argparse.ArgumentParser()
parser.add_argument(
    'input_path',
    type=str,
    help='Path to original sdxl model'
)
parser.add_argument(
    'output_path',
    type=str,
    help='output path'
)
parser.add_argument('--sdxl', action='store_true', help='is sdxl model')
parser.add_argument('--refiner', action='store_true', help='is refiner model')
parser.add_argument('--ssd', action='store_true', help='is ssd model')
parser.add_argument('--sd2', action='store_true', help='is sd 2 model')
args = parser.parse_args()
device = torch.device('cpu')
dtype = torch.float32
print(f"Loading model from {args.input_path}")
diffusers_model_config = ModelConfig(
    name_or_path=args.input_path,
    is_xl=args.sdxl,
    is_v2=args.sd2,
    is_ssd=args.ssd,
    dtype=dtype,
)
diffusers_sd = StableDiffusion(
    model_config=diffusers_model_config,
    device=device,
    dtype=dtype,
)
diffusers_sd.load_model()
print(f"Loaded model from {args.input_path}")
diffusers_sd.pipeline.fuse_lora()
meta = OrderedDict()
diffusers_sd.save(args.output_path, meta=meta)
print(f"Saved to {args.output_path}")

View File

@@ -193,6 +193,9 @@ class TrainConfig:
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
+self.negative_prompt = kwargs.get('negative_prompt', None)
+# multiplier applied to loss on regularization images
+self.reg_weight = kwargs.get('reg_weight', 1.0)
# dropout that happens before encoding. It functions independently per text encoder
self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0)
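A hedged configuration sketch using the two new options (values are illustrative; this assumes TrainConfig accepts these as keyword arguments, as the kwargs.get calls above suggest):

    from toolkit.config_modules import TrainConfig

    train_config = TrainConfig(
        negative_prompt="blurry, low quality, jpeg artifacts",  # encoded alongside each positive prompt
        reg_weight=0.5,  # halve the loss contribution of regularization images
    )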

View File

@@ -77,6 +77,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict:
def load_metadata_from_safetensors(file_path: str) -> OrderedDict:
-    with safe_open(file_path, framework="pt") as f:
-        metadata = f.metadata()
-        return parse_metadata_from_safetensors(metadata)
+    try:
+        with safe_open(file_path, framework="pt") as f:
+            metadata = f.metadata()
+            return parse_metadata_from_safetensors(metadata)
+    except Exception as e:
+        print(f"Error loading metadata from {file_path}: {e}")
+        return OrderedDict()
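With the try/except added above, a file that cannot be opened or carries no metadata now degrades to an empty OrderedDict instead of raising. A small usage sketch (the import path and file name are assumptions for illustration):

    from toolkit.metadata import load_metadata_from_safetensors  # assumed module path

    meta = load_metadata_from_safetensors("/path/to/checkpoint.safetensors")
    if not meta:
        # Either the metadata was missing or loading failed; both cases now return an empty dict.
        print("No metadata found; continuing with defaults.")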

View File

@@ -254,6 +254,7 @@ class ToolkitModuleMixin:
multiplier_batch_size = multiplier.size(0)
if lora_output_batch_size != multiplier_batch_size:
    num_interleaves = lora_output_batch_size // multiplier_batch_size
+    # todo check if this is correct, do we just concat when doing cfg?
    multiplier = multiplier.repeat_interleave(num_interleaves)
x = org_forwarded + broadcast_and_multiply(lora_output, multiplier)
@@ -470,11 +471,11 @@ class ToolkitNetworkMixin:
self.torch_multiplier = tensor_multiplier.clone().detach()
@property
-def multiplier(self) -> Union[float, List[float]]:
+def multiplier(self) -> Union[float, List[float], List[List[float]]]:
    return self._multiplier
@multiplier.setter
-def multiplier(self, value: Union[float, List[float]]):
+def multiplier(self, value: Union[float, List[float], List[List[float]]]):
    # it takes time to update all the multipliers, so we only do it if the value has changed
    if self._multiplier == value:
        return
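A numeric sketch of the batch-size reconciliation above (values are illustrative): when the LoRA output batch is larger than the multiplier batch, for example a CFG-style doubled batch, the per-sample multipliers are interleaved up to the larger size.

    import torch

    multiplier = torch.tensor([1.0, 0.5])   # one multiplier per original sample (B = 2)
    lora_output_batch_size = 4              # e.g. conditional + unconditional halves (2B)

    num_interleaves = lora_output_batch_size // multiplier.size(0)
    expanded = multiplier.repeat_interleave(num_interleaves)
    # expanded -> tensor([1.0, 1.0, 0.5, 0.5])
    # Note: if the doubled batch is built by concatenation ([cond, uncond]) rather than by
    # interleaving samples, this pairing would be off -- which is what the todo above questions.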