mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-02-05 13:09:57 +00:00
Fixes for training iLoRA on Flux
This commit is contained in:
@@ -136,6 +136,8 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
def down_forward(self, x, *args, **kwargs):
|
||||
# get the embed
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
if x.dtype != self.embed.dtype:
|
||||
x = x.to(self.embed.dtype)
|
||||
down_size = math.prod(self.down_shape)
|
||||
down_weight = self.embed[:, :down_size]
|
||||
|
||||
@@ -170,6 +172,8 @@ class InstantLoRAMidModule(torch.nn.Module):
|
||||
|
||||
def up_forward(self, x, *args, **kwargs):
|
||||
self.embed = self.instant_lora_module_ref().img_embeds[self.index]
|
||||
if x.dtype != self.embed.dtype:
|
||||
x = x.to(self.embed.dtype)
|
||||
up_size = math.prod(self.up_shape)
|
||||
up_weight = self.embed[:, -up_size:]
|
||||
|
||||
@@ -211,7 +215,8 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
vision_tokens: int,
|
||||
head_dim: int,
|
||||
num_heads: int, # number of heads in the resampler
|
||||
sd: 'StableDiffusion'
|
||||
sd: 'StableDiffusion',
|
||||
config=None
|
||||
):
|
||||
super(InstantLoRAModule, self).__init__()
|
||||
# self.linear = torch.nn.Linear(2, 1)
|
||||
|
||||
@@ -175,6 +175,7 @@ class ToolkitModuleMixin:
|
||||
lx = self.lora_down(x)
|
||||
except RuntimeError as e:
|
||||
print(f"Error in {self.__class__.__name__} lora_down")
|
||||
print(e)
|
||||
|
||||
if isinstance(self.dropout, nn.Dropout) or isinstance(self.dropout, nn.Identity):
|
||||
lx = self.dropout(lx)
|
||||
|
||||
Reference in New Issue
Block a user