From 1d3de678aaa265aa77ffaedecea3ca152e2d33a9 Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Mon, 9 Oct 2023 06:21:49 -0600 Subject: [PATCH] fixed bug with trigger word embedding. Allow control images to load from the dataloader or legacy way --- extensions_built_in/sd_trainer/SDTrainer.py | 5 ++++- jobs/process/BaseSDTrainProcess.py | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index bb767786..9cc223ca 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -102,7 +102,10 @@ class SDTrainer(BaseSDTrainProcess): sigmas = None if self.adapter: # todo move this to data loader - adapter_images = self.get_adapter_images(batch) + if batch.control_tensor is not None: + adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach() + else: + adapter_images = self.get_adapter_images(batch) # not 100% sure what this does. But they do it here # https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170 # sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype) diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 9b3a5b5c..dd222bd2 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -231,8 +231,8 @@ class BaseSDTrainProcess(BaseTrainProcess): if self.trigger_word is not None: # just so auto1111 will pick it up o_dict['ss_tag_frequency'] = { - [self.trigger_word ]: { - [self.trigger_word ]: 1 + f"1_{self.trigger_word}": { + f"{self.trigger_word}": 1 } }