rework ilora
@@ -15,34 +15,23 @@ class InstantLoRAMidModule(torch.nn.Module):
     def __init__(
             self,
             dim: int,
-            vision_tokens: int,
-            vision_hidden_size: int,
+            index: int,
             lora_module: 'LoRAModule',
             instant_lora_module: 'InstantLoRAModule'
     ):
         super(InstantLoRAMidModule, self).__init__()
         self.dim = dim
-        self.vision_tokens = vision_tokens
-        self.vision_hidden_size = vision_hidden_size
+        self.index = index
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
 
-        self.zip = ZipperModule(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.dim,
-            out_tokens=1,
-            hidden_size=self.dim,
-            hidden_tokens=self.vision_tokens
-        )
-
     def forward(self, x, *args, **kwargs):
         # get the vector
         img_embeds = self.instant_lora_module_ref().img_embeds
         # project it
-        scaler = self.zip(img_embeds)  # (batch_size, 1, dim)
+        scaler = img_embeds[:, self.index, :]
 
-        # remove the channel dim
+        # remove the channel dim (index)
         scaler = scaler.squeeze(1)
 
         # double up if batch is 2x the size on x (cfg)
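In this first hunk, each InstantLoRAMidModule stops projecting the vision embeddings itself (its per-module ZipperModule is deleted) and instead takes an index, slicing out its own token from embeddings the parent InstantLoRAModule has already resampled to one token per LoRA module. A minimal shape sketch of that selection, as standalone Python with illustrative names (it assumes img_embeds has already been resampled to (batch_size, num_lora_modules, dim), which the second hunk sets up):

import torch

batch_size, num_lora_modules, dim = 4, 3, 8
# stand-in for the parent module's cached, already-resampled vision embeddings
img_embeds = torch.randn(batch_size, num_lora_modules, dim)

index = 1  # this mid module's position in the parent's module list
scaler = img_embeds[:, index, :]  # (batch_size, dim): one token per module
assert scaler.shape == (batch_size, dim)

The retained scaler.squeeze(1) from the old path is now effectively a no-op, since squeeze(1) only removes dimension 1 when its size is 1. The context comment about doubling up refers to classifier-free guidance, where x arrives with twice the batch of the cached embeddings; the scaler presumably gets repeated along the batch to match, though that code is outside this hunk.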
@@ -84,23 +73,30 @@ class InstantLoRAModule(torch.nn.Module):
         # disable merging in. It is slower on inference
         self.sd_ref().network.can_merge_in = False
 
-        self.resampler = ZipperResampler(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.vision_hidden_size,
-            out_tokens=self.vision_tokens,
-            hidden_size=self.vision_hidden_size,
-            hidden_tokens=self.vision_tokens
-        )
-
         self.ilora_modules = torch.nn.ModuleList()
 
         lora_modules = self.sd_ref().network.get_all_modules()
 
-        for lora_module in lora_modules:
+        # resample the output so each module gets one token with a size of its dim so we can multiply by that
+        self.resampler = ZipperResampler(
+            in_size=self.vision_hidden_size,
+            in_tokens=self.vision_tokens,
+            out_size=self.dim,
+            out_tokens=len(lora_modules),
+            hidden_size=self.vision_hidden_size,
+            hidden_tokens=self.vision_tokens,
+            num_blocks=1,
+        )
+
+        for idx, lora_module in enumerate(lora_modules):
             # add a new mid module that will take the original forward and add a vector to it
             # this will be used to add the vector to the original forward
-            mid_module = InstantLoRAMidModule(self.dim, self.vision_tokens, self.vision_hidden_size, lora_module, self)
+            mid_module = InstantLoRAMidModule(
+                self.dim,
+                idx,
+                lora_module,
+                self
+            )
 
             self.ilora_modules.append(mid_module)
             # replace the LoRA lora_mid
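The second hunk moves the ZipperResampler construction below get_all_modules(), since its out_tokens now depends on len(lora_modules), and retargets it: instead of mapping vision tokens back to vision_hidden_size, it emits one dim-sized token per LoRA module, and the loop hands each InstantLoRAMidModule its idx into that output. A toy sketch of the resulting wiring, with hypothetical names standing in for ZipperResampler and the repo's module classes:

import weakref
import torch

class Parent(torch.nn.Module):
    def __init__(self, hidden_size: int, dim: int, num_modules: int):
        super().__init__()
        # stand-in for ZipperResampler: vision features -> one token per child
        self.proj = torch.nn.Linear(hidden_size, num_modules * dim)
        self.dim = dim
        self.num_modules = num_modules
        self.mid_modules = torch.nn.ModuleList(
            Child(idx, self) for idx in range(num_modules)
        )
        self.img_embeds = None

    def forward(self, vision_feats: torch.Tensor):
        # (batch, hidden_size) -> (batch, num_modules, dim), cached for children
        self.img_embeds = self.proj(vision_feats).view(-1, self.num_modules, self.dim)

class Child(torch.nn.Module):
    def __init__(self, index: int, parent: 'Parent'):
        super().__init__()
        self.index = index
        # weakref, as in the commit: a plain attribute would register the
        # parent as a submodule of each child and create a registration cycle
        self.parent_ref = weakref.ref(parent)

    def forward(self, x: torch.Tensor):
        scaler = self.parent_ref().img_embeds[:, self.index, :]
        return x * scaler  # per-sample scaling of the wrapped output

parent = Parent(hidden_size=16, dim=8, num_modules=3)
parent(torch.randn(4, 16))                      # cache img_embeds
out = parent.mid_modules[0](torch.randn(4, 8))  # scaled by token 0

One upshot of the rework visible in the diff itself: a single shared resampler replaces the per-module ZipperModule projections, so the projection is computed once per batch and its capacity is shared across all LoRA modules.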