WIP Ilora

This commit is contained in:
Jaret Burkett
2024-06-14 09:31:01 -06:00
parent bd10d2d668
commit 37cebd9458
6 changed files with 57 additions and 29 deletions

View File

@@ -181,6 +181,7 @@ class InstantLoRAModule(torch.nn.Module):
vision_hidden_size: int,
vision_tokens: int,
head_dim: int,
num_heads: int, # number of heads in the resampler
sd: 'StableDiffusion'
):
super(InstantLoRAModule, self).__init__()
@@ -190,6 +191,7 @@ class InstantLoRAModule(torch.nn.Module):
self.vision_hidden_size = vision_hidden_size
self.vision_tokens = vision_tokens
self.head_dim = head_dim
self.num_heads = num_heads
# stores the projection vector. Grabbed by modules
self.img_embeds: List[torch.Tensor] = None
@@ -243,7 +245,7 @@ class InstantLoRAModule(torch.nn.Module):
depth=4,
dim_head=64,
heads=12,
num_queries=1, # output tokens
num_queries=num_heads, # output tokens
embedding_dim=vision_hidden_size,
max_seq_len=vision_tokens,
output_dim=head_dim,
@@ -261,25 +263,26 @@ class InstantLoRAModule(torch.nn.Module):
self.migrate_weight_mapping()
def migrate_weight_mapping(self):
# Remaps the keys of self.weight_mapping through the inverse of the network keymap.
# NOTE(review): indentation is stripped in this view, and an active body is followed
# by a bare `return` plus a commented-out copy of the same logic. This looks like the
# removed and added sides of a diff shown together - confirm which version is live.
# changes the names of the modules to common ones
keymap = self.sd_ref().network.get_keymap()
save_keymap = {}
if keymap is not None:
# keymap is iterated as ldm_key -> diffusers_key; build the reverse lookup
for ldm_key, diffusers_key in keymap.items():
# invert them
save_keymap[diffusers_key] = ldm_key
new_keymap = {}
# NOTE(review): two-target unpacking requires weight_mapping to yield (key, value)
# pairs; if it is a dict this would need .items() - confirm against where it is built.
# The value stored back into self.weight_mapping below is a dict.
for key, value in self.weight_mapping:
if key in save_keymap:
new_keymap[save_keymap[key]] = value
else:
# keys missing from the inverted keymap are reported and kept under their original name
print(f"Key {key} not found in keymap")
new_keymap[key] = value
self.weight_mapping = new_keymap
else:
print("No keymap found. Using default names")
return
# Everything after this bare return is commented out, so nothing below executes.
return
# # changes the names of the modules to common ones
# keymap = self.sd_ref().network.get_keymap()
# save_keymap = {}
# if keymap is not None:
# for ldm_key, diffusers_key in keymap.items():
# # invert them
# save_keymap[diffusers_key] = ldm_key
#
# new_keymap = {}
# for key, value in self.weight_mapping:
# if key in save_keymap:
# new_keymap[save_keymap[key]] = value
# else:
# print(f"Key {key} not found in keymap")
# new_keymap[key] = value
# self.weight_mapping = new_keymap
# else:
# print("No keymap found. Using default names")
# return
def forward(self, img_embeds):
@@ -291,7 +294,8 @@ class InstantLoRAModule(torch.nn.Module):
img_embeds = self.resampler(img_embeds)
img_embeds = self.proj_module(img_embeds)
if len(img_embeds.shape) == 3:
img_embeds = img_embeds.squeeze(1)
# merge the heads
img_embeds = img_embeds.mean(dim=1)
self.img_embeds = []
# get all the slices
@@ -304,6 +308,11 @@ class InstantLoRAModule(torch.nn.Module):
def get_additional_save_metadata(self) -> Dict[str, Any]:
    """Return the extra metadata entries to store alongside the saved weights.

    The result holds the weight mapping plus the configuration attributes
    read from this instance (``num_heads``, ``vision_hidden_size``,
    ``head_dim``, ``vision_tokens``, ``output_size``).

    Returns:
        A new dict keyed by attribute name; values are the current
        attribute values (not copies).
    """
    # "weight_mapping" appears exactly once: the stale duplicate entry
    # (same key, no trailing comma) made the dict literal invalid.
    return {
        "weight_mapping": self.weight_mapping,
        "num_heads": self.num_heads,
        "vision_hidden_size": self.vision_hidden_size,
        "head_dim": self.head_dim,
        "vision_tokens": self.vision_tokens,
        "output_size": self.output_size,
    }