From 13256135836c5dba00900b2d7d96adc5eb2038e2 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Thu, 29 Feb 2024 07:55:52 -0700
Subject: [PATCH] rework ilora

---
 toolkit/models/ilora.py | 48 +++++++++++++++++++----------------------
 1 file changed, 22 insertions(+), 26 deletions(-)

diff --git a/toolkit/models/ilora.py b/toolkit/models/ilora.py
index 9d18b6c4..a0c056c1 100644
--- a/toolkit/models/ilora.py
+++ b/toolkit/models/ilora.py
@@ -15,34 +15,23 @@ class InstantLoRAMidModule(torch.nn.Module):
     def __init__(
             self,
             dim: int,
-            vision_tokens: int,
-            vision_hidden_size: int,
+            index: int,
             lora_module: 'LoRAModule',
             instant_lora_module: 'InstantLoRAModule'
     ):
         super(InstantLoRAMidModule, self).__init__()
         self.dim = dim
-        self.vision_tokens = vision_tokens
-        self.vision_hidden_size = vision_hidden_size
+        self.index = index
         self.lora_module_ref = weakref.ref(lora_module)
         self.instant_lora_module_ref = weakref.ref(instant_lora_module)
 
-        self.zip = ZipperModule(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.dim,
-            out_tokens=1,
-            hidden_size=self.dim,
-            hidden_tokens=self.vision_tokens
-        )
-
     def forward(self, x, *args, **kwargs):
         # get the vector
         img_embeds = self.instant_lora_module_ref().img_embeds
         # project it
-        scaler = self.zip(img_embeds)  # (batch_size, 1, dim)
+        scaler = img_embeds[:, self.index, :]
 
-        # remove the channel dim
+        # remove the channel dim (index)
         scaler = scaler.squeeze(1)
 
         # double up if batch is 2x the size on x (cfg)
@@ -84,23 +73,30 @@ class InstantLoRAModule(torch.nn.Module):
         # disable merging in. It is slower on inference
         self.sd_ref().network.can_merge_in = False
 
-        self.resampler = ZipperResampler(
-            in_size=self.vision_hidden_size,
-            in_tokens=self.vision_tokens,
-            out_size=self.vision_hidden_size,
-            out_tokens=self.vision_tokens,
-            hidden_size=self.vision_hidden_size,
-            hidden_tokens=self.vision_tokens
-        )
-
         self.ilora_modules = torch.nn.ModuleList()
 
         lora_modules = self.sd_ref().network.get_all_modules()
 
-        for lora_module in lora_modules:
+        # resample the output so each module gets one token with a size of its dim so we can multiply by that
+        self.resampler = ZipperResampler(
+            in_size=self.vision_hidden_size,
+            in_tokens=self.vision_tokens,
+            out_size=self.dim,
+            out_tokens=len(lora_modules),
+            hidden_size=self.vision_hidden_size,
+            hidden_tokens=self.vision_tokens,
+            num_blocks=1,
+        )
+
+        for idx, lora_module in enumerate(lora_modules):
             # add a new mid module that will take the original forward and add a vector to it
             # this will be used to add the vector to the original forward
-            mid_module = InstantLoRAMidModule(self.dim, self.vision_tokens, self.vision_hidden_size, lora_module, self)
+            mid_module = InstantLoRAMidModule(
+                self.dim,
+                idx,
+                lora_module,
+                self
+            )
            self.ilora_modules.append(mid_module)
 
             # replace the LoRA lora_mid
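
For context, the rework drops the per-module ZipperModule projections and instead builds one shared ZipperResampler whose output carries one dim-sized token per LoRA module; each InstantLoRAMidModule then only slices out the token at its own index. Below is a minimal shape-only sketch of the new flow, with a random tensor standing in for the real resampler output and example values for the batch size, module count, and dim (all assumptions for illustration, not values from the patch):

import torch

batch_size = 2
num_lora_modules = 3   # len(lora_modules) in the patch
dim = 64               # self.dim in the patch

# Stand-in for the shared ZipperResampler output:
# shape (batch, out_tokens=len(lora_modules), out_size=dim)
img_embeds = torch.randn(batch_size, num_lora_modules, dim)

# What each InstantLoRAMidModule.forward now does: pick its own token by index
# (self.index is assigned via enumerate(lora_modules) in the patch).
index = 1
scaler = img_embeds[:, index, :]   # (batch, dim)
assert scaler.shape == (batch_size, dim)

Compared to the old layout, the projection work happens once in the shared resampler rather than in every mid module, so the per-module path is reduced to an index lookup.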