diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index 49f39725..f33f6b32 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -1,11 +1,12 @@ from collections import OrderedDict -from typing import Union +from typing import Union, Literal, List from diffusers import T2IAdapter from toolkit import train_tools -from toolkit.basic import value_map, adain +from toolkit.basic import value_map, adain, get_mean_std from toolkit.config_modules import GuidanceConfig -from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO, FileItemDTO +from toolkit.image_utils import show_tensors, show_latents from toolkit.ip_adapter import IPAdapter from toolkit.prompt_utils import PromptEmbeds from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork @@ -309,7 +310,6 @@ class SDTrainer(BaseSDTrainProcess): pass def hook_train_loop(self, batch: 'DataLoaderBatchDTO'): - self.timer.start('preprocess_batch') batch = self.preprocess_batch(batch) dtype = get_torch_dtype(self.train_config.dtype) @@ -322,6 +322,7 @@ class SDTrainer(BaseSDTrainProcess): match_adapter_assist = False + # check if we are matching the adapter assistant if self.assistant_adapter: if self.train_config.match_adapter_chance == 1.0: @@ -334,6 +335,12 @@ class SDTrainer(BaseSDTrainProcess): self.timer.stop('preprocess_batch') with torch.no_grad(): + loss_multiplier = torch.ones((noisy_latents.shape[0], 1, 1, 1), device=self.device_torch, dtype=dtype) + for idx, file_item in enumerate(batch.file_items): + if file_item.is_reg: + loss_multiplier[idx] = loss_multiplier[idx] * self.train_config.reg_weight + + adapter_images = None sigmas = None if has_adapter_img and (self.adapter or self.assistant_adapter): @@ -471,6 +478,12 @@ class SDTrainer(BaseSDTrainProcess): mask_multiplier_list, prompt_2_list ): + if self.train_config.negative_prompt is not None: + # add negative prompt + conditioned_prompts = conditioned_prompts + [self.train_config.negative_prompt for x in + range(len(conditioned_prompts))] + if prompt_2 is not None: + prompt_2 = prompt_2 + [self.train_config.negative_prompt for x in range(len(prompt_2))] with network: with self.timer('encode_prompt'): @@ -585,6 +598,8 @@ class SDTrainer(BaseSDTrainProcess): raise ValueError("loss is nan") with self.timer('backward'): + # todo we have multiplier seperated. works for now as res are not in same batch, but need to change + loss = loss * loss_multiplier.mean() # IMPORTANT if gradient checkpointing do not leave with network when doing backward # it will destroy the gradients. This is because the network is a context manager # and will change the multipliers back to 0.0 when exiting. They will be diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index f81f5e19..09ec99e2 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -1259,6 +1259,7 @@ class BaseSDTrainProcess(BaseTrainProcess): ################################################################### start_step_num = self.step_num + did_first_flush = False for step in range(start_step_num, self.train_config.steps): self.step_num = step # default to true so various things can turn it off @@ -1332,6 +1333,9 @@ class BaseSDTrainProcess(BaseTrainProcess): self.timer.start('train_loop') loss_dict = self.hook_train_loop(batch) self.timer.stop('train_loop') + if not did_first_flush: + flush() + did_first_flush = True # flush() # setup the networks to gradient checkpointing and everything works diff --git a/scripts/make_diffusers_model.py b/scripts/make_diffusers_model.py new file mode 100644 index 00000000..1ec6c93a --- /dev/null +++ b/scripts/make_diffusers_model.py @@ -0,0 +1,57 @@ +import argparse +from collections import OrderedDict + +import torch + +from toolkit.config_modules import ModelConfig +from toolkit.stable_diffusion_model import StableDiffusion + + +parser = argparse.ArgumentParser() +parser.add_argument( + 'input_path', + type=str, + help='Path to original sdxl model' +) +parser.add_argument( + 'output_path', + type=str, + help='output path' +) +parser.add_argument('--sdxl', action='store_true', help='is sdxl model') +parser.add_argument('--refiner', action='store_true', help='is refiner model') +parser.add_argument('--ssd', action='store_true', help='is ssd model') +parser.add_argument('--sd2', action='store_true', help='is sd 2 model') + +args = parser.parse_args() +device = torch.device('cpu') +dtype = torch.float32 + +print(f"Loading model from {args.input_path}") + + +diffusers_model_config = ModelConfig( + name_or_path=args.input_path, + is_xl=args.sdxl, + is_v2=args.sd2, + is_ssd=args.ssd, + dtype=dtype, + ) +diffusers_sd = StableDiffusion( + model_config=diffusers_model_config, + device=device, + dtype=dtype, +) +diffusers_sd.load_model() + + +print(f"Loaded model from {args.input_path}") + +diffusers_sd.pipeline.fuse_lora() + +meta = OrderedDict() + +diffusers_sd.save(args.output_path, meta=meta) + + +print(f"Saved to {args.output_path}") diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index fb69c2fc..b52eddfe 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -193,6 +193,9 @@ class TrainConfig: self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None) self.noise_multiplier = kwargs.get('noise_multiplier', 1.0) self.img_multiplier = kwargs.get('img_multiplier', 1.0) + self.negative_prompt = kwargs.get('negative_prompt', None) + # multiplier applied to loos on regularization images + self.reg_weight = kwargs.get('reg_weight', 1.0) # dropout that happens before encoding. It functions independently per text encoder self.prompt_dropout_prob = kwargs.get('prompt_dropout_prob', 0.0) diff --git a/toolkit/metadata.py b/toolkit/metadata.py index b652dcef..bf969f09 100644 --- a/toolkit/metadata.py +++ b/toolkit/metadata.py @@ -77,6 +77,10 @@ def parse_metadata_from_safetensors(meta: OrderedDict) -> OrderedDict: def load_metadata_from_safetensors(file_path: str) -> OrderedDict: - with safe_open(file_path, framework="pt") as f: - metadata = f.metadata() - return parse_metadata_from_safetensors(metadata) + try: + with safe_open(file_path, framework="pt") as f: + metadata = f.metadata() + return parse_metadata_from_safetensors(metadata) + except Exception as e: + print(f"Error loading metadata from {file_path}: {e}") + return OrderedDict() diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index 172925d2..0644f494 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -254,6 +254,7 @@ class ToolkitModuleMixin: multiplier_batch_size = multiplier.size(0) if lora_output_batch_size != multiplier_batch_size: num_interleaves = lora_output_batch_size // multiplier_batch_size + # todo check if this is correct, do we just concat when doing cfg? multiplier = multiplier.repeat_interleave(num_interleaves) x = org_forwarded + broadcast_and_multiply(lora_output, multiplier) @@ -470,11 +471,11 @@ class ToolkitNetworkMixin: self.torch_multiplier = tensor_multiplier.clone().detach() @property - def multiplier(self) -> Union[float, List[float]]: + def multiplier(self) -> Union[float, List[float], List[List[float]]]: return self._multiplier @multiplier.setter - def multiplier(self, value: Union[float, List[float]]): + def multiplier(self, value: Union[float, List[float], List[List[float]]]): # it takes time to update all the multipliers, so we only do it if the value has changed if self._multiplier == value: return