Fixes for training ilora on Flux

Jaret Burkett
2024-08-31 16:55:26 -06:00
parent 3e71a99df0
commit 40f5c59da0
2 changed files with 7 additions and 1 deletion


@@ -136,6 +136,8 @@ class InstantLoRAMidModule(torch.nn.Module):
    def down_forward(self, x, *args, **kwargs):
        # get the embed
        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
        if x.dtype != self.embed.dtype:
            x = x.to(self.embed.dtype)
        down_size = math.prod(self.down_shape)
        down_weight = self.embed[:, :down_size]
@@ -170,6 +172,8 @@ class InstantLoRAMidModule(torch.nn.Module):
    def up_forward(self, x, *args, **kwargs):
        self.embed = self.instant_lora_module_ref().img_embeds[self.index]
        if x.dtype != self.embed.dtype:
            x = x.to(self.embed.dtype)
        up_size = math.prod(self.up_shape)
        up_weight = self.embed[:, -up_size:]
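
Both hunks apply the same guard: on Flux the incoming activation and the projected image embeds can end up in different dtypes (Flux typically trains in bfloat16 while the embeds may still be float32), and the matmul against the embed-derived LoRA weights then fails. A standalone sketch of the mismatch and the cast the patch applies; the tensor shapes are illustrative, not from the source:

    import torch

    x = torch.randn(1, 16, dtype=torch.bfloat16)      # activation, as on Flux
    weight = torch.randn(16, 4, dtype=torch.float32)  # weight sliced from the embed

    # Without the guard, x @ weight raises a dtype-mismatch RuntimeError.
    if x.dtype != weight.dtype:
        x = x.to(weight.dtype)
    out = x @ weight  # both operands now share a dtype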
@@ -211,7 +215,8 @@ class InstantLoRAModule(torch.nn.Module):
            vision_tokens: int,
            head_dim: int,
            num_heads: int,  # number of heads in the resampler
            sd: 'StableDiffusion'
            sd: 'StableDiffusion',
            config=None
    ):
        super(InstantLoRAModule, self).__init__()
        # self.linear = torch.nn.Linear(2, 1)
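
The constructor change is backwards-compatible: config is appended with a default of None, so pre-patch call sites that pass only the original arguments keep working. A minimal stand-in showing why the default matters (Demo is hypothetical, not from the source):

    import torch

    class Demo(torch.nn.Module):
        # Stand-in for the InstantLoRAModule signature change: a trailing
        # parameter with a default keeps existing call sites valid.
        def __init__(self, num_heads: int, sd=None, config=None):
            super().__init__()
            self.num_heads = num_heads
            self.sd = sd
            self.config = config

    old_call = Demo(8, sd=object())             # pre-patch style still works
    new_call = Demo(8, sd=object(), config={})  # new callers may pass a config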


@@ -175,6 +175,7 @@ class ToolkitModuleMixin:
            lx = self.lora_down(x)
        except RuntimeError as e:
            print(f"Error in {self.__class__.__name__} lora_down")
            print(e)
        if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
            lx = self.dropout(lx)
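
The network-mixin hunk only adds print(e), so the exception text (typically a shape or dtype mismatch) is logged alongside the failing module's class name. A sketch of the surrounding try/except as implied by the diff; the method name and the re-raise are assumptions, since the hunk ends before showing how the failure path continues:

    import torch.nn as nn

    def _lora_forward(self, x):  # hypothetical name for the surrounding method
        try:
            lx = self.lora_down(x)
        except RuntimeError as e:
            print(f"Error in {self.__class__.__name__} lora_down")
            print(e)  # the line this commit adds
            raise  # assumed: lx would otherwise be unbound below
        if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
            lx = self.dropout(lx)
        return lx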