mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
WIP Ilora
This commit is contained in:
@@ -181,6 +181,7 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
vision_hidden_size: int,
|
||||
vision_tokens: int,
|
||||
head_dim: int,
|
||||
num_heads: int, # number of heads in the resampler
|
||||
sd: 'StableDiffusion'
|
||||
):
|
||||
super(InstantLoRAModule, self).__init__()
|
||||
@@ -190,6 +191,7 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
self.vision_hidden_size = vision_hidden_size
|
||||
self.vision_tokens = vision_tokens
|
||||
self.head_dim = head_dim
|
||||
self.num_heads = num_heads
|
||||
|
||||
# stores the projection vector. Grabbed by modules
|
||||
self.img_embeds: List[torch.Tensor] = None
|
||||
@@ -243,7 +245,7 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=1, # output tokens
|
||||
num_queries=num_heads, # output tokens
|
||||
embedding_dim=vision_hidden_size,
|
||||
max_seq_len=vision_tokens,
|
||||
output_dim=head_dim,
|
||||
@@ -261,25 +263,26 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
self.migrate_weight_mapping()
|
||||
|
||||
def migrate_weight_mapping(self):
|
||||
# changes the names of the modules to common ones
|
||||
keymap = self.sd_ref().network.get_keymap()
|
||||
save_keymap = {}
|
||||
if keymap is not None:
|
||||
for ldm_key, diffusers_key in keymap.items():
|
||||
# invert them
|
||||
save_keymap[diffusers_key] = ldm_key
|
||||
|
||||
new_keymap = {}
|
||||
for key, value in self.weight_mapping:
|
||||
if key in save_keymap:
|
||||
new_keymap[save_keymap[key]] = value
|
||||
else:
|
||||
print(f"Key {key} not found in keymap")
|
||||
new_keymap[key] = value
|
||||
self.weight_mapping = new_keymap
|
||||
else:
|
||||
print("No keymap found. Using default names")
|
||||
return
|
||||
return
|
||||
# # changes the names of the modules to common ones
|
||||
# keymap = self.sd_ref().network.get_keymap()
|
||||
# save_keymap = {}
|
||||
# if keymap is not None:
|
||||
# for ldm_key, diffusers_key in keymap.items():
|
||||
# # invert them
|
||||
# save_keymap[diffusers_key] = ldm_key
|
||||
#
|
||||
# new_keymap = {}
|
||||
# for key, value in self.weight_mapping:
|
||||
# if key in save_keymap:
|
||||
# new_keymap[save_keymap[key]] = value
|
||||
# else:
|
||||
# print(f"Key {key} not found in keymap")
|
||||
# new_keymap[key] = value
|
||||
# self.weight_mapping = new_keymap
|
||||
# else:
|
||||
# print("No keymap found. Using default names")
|
||||
# return
|
||||
|
||||
|
||||
def forward(self, img_embeds):
|
||||
@@ -291,7 +294,8 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
img_embeds = self.resampler(img_embeds)
|
||||
img_embeds = self.proj_module(img_embeds)
|
||||
if len(img_embeds.shape) == 3:
|
||||
img_embeds = img_embeds.squeeze(1)
|
||||
# merge the heads
|
||||
img_embeds = img_embeds.mean(dim=1)
|
||||
|
||||
self.img_embeds = []
|
||||
# get all the slices
|
||||
@@ -304,6 +308,11 @@ class InstantLoRAModule(torch.nn.Module):
|
||||
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||
# save the weight mapping
|
||||
return {
|
||||
"weight_mapping": self.weight_mapping
|
||||
"weight_mapping": self.weight_mapping,
|
||||
"num_heads": self.num_heads,
|
||||
"vision_hidden_size": self.vision_hidden_size,
|
||||
"head_dim": self.head_dim,
|
||||
"vision_tokens": self.vision_tokens,
|
||||
"output_size": self.output_size,
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user