rework ilora

Author: Jaret Burkett
Date: 2024-02-29 07:55:52 -07:00
parent 337945de9a
commit 1325613583


@@ -15,34 +15,23 @@ class InstantLoRAMidModule(torch.nn.Module):
     def __init__(
             self,
             dim: int,
-            vision_tokens: int,
-            vision_hidden_size: int,
+            index: int,
             lora_module: 'LoRAModule',
             instant_lora_module: 'InstantLoRAModule'
     ):
         super(InstantLoRAMidModule, self).__init__()
         self.dim = dim
-        self.vision_tokens = vision_tokens
-        self.vision_hidden_size = vision_hidden_size
+        self.index = index
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
-        self.zip = ZipperModule(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.dim,
-            out_tokens=1,
-            hidden_size=self.dim,
-            hidden_tokens=self.vision_tokens
-        )
     def forward(self, x, *args, **kwargs):
         # get the vector
         img_embeds = self.instant_lora_module_ref().img_embeds
-        # project it
-        scaler = self.zip(img_embeds) # (batch_size, 1, dim)
+        scaler = img_embeds[:, self.index, :]
-        # remove the channel dim
+        # remove the channel dim (index)
         scaler = scaler.squeeze(1)
         # double up if batch is 2x the size on x (cfg)
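After this change a mid module no longer runs its own ZipperModule projection; it reads one precomputed token out of the shared img_embeds by its index. A minimal runnable sketch of that selection follows; the shapes and the CFG doubling after the squeeze are assumptions based on the surviving comments, not lines from this commit.

import torch

# assumed shape of the shared resampler output: one token per LoRA module
img_embeds = torch.randn(2, 16, 4)        # (batch, num_lora_modules, dim)
index = 3                                 # this mid module's slot
x = torch.randn(4, 4)                     # LoRA mid activation; batch doubled by CFG (assumption)

scaler = img_embeds[:, index, :]          # (batch, dim), selects this module's token
if x.shape[0] == scaler.shape[0] * 2:     # "double up if batch is 2x the size on x (cfg)"
    scaler = torch.cat([scaler, scaler], dim=0)
out = x * scaler                          # per-channel scaling of the mid output (assumed use)
print(out.shape)                          # torch.Size([4, 4])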
@@ -84,23 +73,30 @@ class InstantLoRAModule(torch.nn.Module):
         # disable merging in. It is slower on inference
         self.sd_ref().network.can_merge_in = False
-        self.resampler = ZipperResampler(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.vision_hidden_size,
-            out_tokens=self.vision_tokens,
-            hidden_size=self.vision_hidden_size,
-            hidden_tokens=self.vision_tokens
-        )
         self.ilora_modules = torch.nn.ModuleList()
         lora_modules = self.sd_ref().network.get_all_modules()
-        for lora_module in lora_modules:
+        # resample the output so each module gets one token with a size of its dim so we can multiply by that
+        self.resampler = ZipperResampler(
+            in_size=self.vision_hidden_size,
+            in_tokens=self.vision_tokens,
+            out_size=self.dim,
+            out_tokens=len(lora_modules),
+            hidden_size=self.vision_hidden_size,
+            hidden_tokens=self.vision_tokens,
+            num_blocks=1,
+        )
+        for idx, lora_module in enumerate(lora_modules):
             # add a new mid module that will take the original forward and add a vector to it
             # this will be used to add the vector to the original forward
-            mid_module = InstantLoRAMidModule(self.dim, self.vision_tokens, self.vision_hidden_size, lora_module, self)
+            mid_module = InstantLoRAMidModule(
+                self.dim,
+                idx,
+                lora_module,
+                self
+            )
             self.ilora_modules.append(mid_module)
             # replace the LoRA lora_mid
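At the module level, a single ZipperResampler now emits len(lora_modules) tokens of size dim, and each InstantLoRAMidModule is handed the index of the token it should read at construction time. A toy stand-in for that flow, with ZipperResampler replaced by a plain linear projection and all sizes illustrative rather than taken from the commit:

import torch

batch, vision_tokens, vision_hidden_size = 2, 8, 32   # illustrative sizes
num_lora_modules, dim = 16, 4

vision_out = torch.randn(batch, vision_tokens, vision_hidden_size)

# stand-in for ZipperResampler(out_tokens=len(lora_modules), out_size=dim)
project = torch.nn.Linear(vision_tokens * vision_hidden_size, num_lora_modules * dim)
img_embeds = project(vision_out.flatten(1)).view(batch, num_lora_modules, dim)

# each mid module slices its own token by the idx it was given at construction
for idx in range(num_lora_modules):
    scaler = img_embeds[:, idx, :]                     # (batch, dim)
    assert scaler.shape == (batch, dim)

The effect of out_tokens=len(lora_modules) is that the image embedding is resampled once per forward pass instead of once per LoRA module, which is what the removed per-module ZipperModule did.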