From 40f5c59da0f2a8ee29dac2f91b994753f04878ce Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sat, 31 Aug 2024 16:55:26 -0600
Subject: [PATCH] Fixes for training ilora on flux

---
 toolkit/models/ilora.py   | 7 ++++++-
 toolkit/network_mixins.py | 1 +
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index f292b69..33613ed 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -136,6 +136,8 @@ class InstantLoRAMidModule(torch.nn.Module):
     def down_forward(self, x, *args, **kwargs):
         # get the embed
         self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+        if x.dtype != self.embed.dtype:
+            x = x.to(self.embed.dtype)
         down_size = math.prod(self.down_shape)
         down_weight = self.embed[:, :down_size]
 
@@ -170,6 +172,8 @@ class InstantLoRAMidModule(torch.nn.Module):
 
     def up_forward(self, x, *args, **kwargs):
         self.embed = self.instant_lora_module_ref().img_embeds[self.index]
+        if x.dtype != self.embed.dtype:
+            x = x.to(self.embed.dtype)
         up_size = math.prod(self.up_shape)
         up_weight = self.embed[:, -up_size:]
 
@@ -211,7 +215,8 @@ class InstantLoRAModule(torch.nn.Module):
             vision_tokens: int,
             head_dim: int,
             num_heads: int,  # number of heads in the resampler
-            sd: 'StableDiffusion'
+            sd: 'StableDiffusion',
+            config=None
     ):
         super(InstantLoRAModule, self).__init__()
         # self.linear = torch.nn.Linear(2, 1)
diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py
index e2c4fef..c567af5 100644
--- a/toolkit/network_mixins.py
+++ b/toolkit/network_mixins.py
@@ -175,6 +175,7 @@ class ToolkitModuleMixin:
             lx = self.lora_down(x)
         except RuntimeError as e:
            print(f"Error in {self.__class__.__name__} lora_down")
+            print(e)
         if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
             lx = self.dropout(lx)
 
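
Note (not part of the patch): the two ilora.py hunks add a dtype guard so the incoming activation matches the dtype of the LoRA weights sliced out of the image embedding. Below is a minimal sketch of that guard in isolation, assuming a bfloat16 activation meeting float32 weights (as can happen when the Flux transformer runs in bf16 while the iLoRA projector stays in fp32); the helper name, shapes, and dtypes are illustrative, not taken from the repository.

import torch

def apply_lora_down(x: torch.Tensor, down_weight: torch.Tensor) -> torch.Tensor:
    # Cast the activation to the weight's dtype before the matmul,
    # mirroring the guard the patch adds in down_forward/up_forward.
    if x.dtype != down_weight.dtype:
        x = x.to(down_weight.dtype)
    return x @ down_weight

x = torch.randn(2, 16, dtype=torch.bfloat16)   # hypothetical bf16 activation
w = torch.randn(16, 4, dtype=torch.float32)    # hypothetical fp32 LoRA down weight
out = apply_lora_down(x, w)                    # no dtype-mismatch RuntimeError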