Fixed a bug with trigger word embedding. Allow control images to be loaded either from the dataloader or the legacy way.

This commit is contained in:
Jaret Burkett
2023-10-09 06:21:49 -06:00
parent b1cfafa0c6
commit 1d3de678aa
2 changed files with 6 additions and 3 deletions

View File

@@ -102,7 +102,10 @@ class SDTrainer(BaseSDTrainProcess):
sigmas = None
if self.adapter:
# todo move this to data loader
- adapter_images = self.get_adapter_images(batch)
+ if batch.control_tensor is not None:
+ adapter_images = batch.control_tensor.to(self.device_torch, dtype=dtype).detach()
+ else:
+ adapter_images = self.get_adapter_images(batch)
# not 100% sure what this does. But they do it here
# https://github.com/huggingface/diffusers/blob/38a664a3d61e27ab18cd698231422b3c38d6eebf/examples/t2i_adapter/train_t2i_adapter_sdxl.py#L1170
# sigmas = self.get_sigmas(timesteps, len(noisy_latents.shape), noisy_latents.dtype)

View File

@@ -231,8 +231,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
if self.trigger_word is not None:
# just so auto1111 will pick it up
o_dict['ss_tag_frequency'] = {
- [self.trigger_word ]: {
- [self.trigger_word ]: 1
+ f"1_{self.trigger_word}": {
+ f"{self.trigger_word}": 1
}
}