WIP Ilora

This commit is contained in:
Jaret Burkett
2024-06-14 09:31:01 -06:00
parent bd10d2d668
commit 37cebd9458
6 changed files with 57 additions and 29 deletions

View File

@@ -145,7 +145,8 @@ class CustomAdapter(torch.nn.Module):
self.ilora_module = InstantLoRAModule(
vision_tokens=vision_tokens,
vision_hidden_size=vision_hidden_size,
head_dim=1024,
head_dim=self.config.head_dim,
num_heads=self.config.num_heads,
sd=self.sd_ref()
)
elif self.adapter_type == 'text_encoder':
@@ -878,6 +879,11 @@ class CustomAdapter(torch.nn.Module):
self.vision_encoder.gradient_checkpointing = True
def get_additional_save_metadata(self) -> Dict[str, Any]:
additional = {}
if self.config.type == 'ilora':
return self.ilora_module.get_additional_save_metadata()
return {}
extra = self.ilora_module.get_additional_save_metadata()
for k, v in extra.items():
additional[k] = v
additional['clip_layer'] = self.config.clip_layer
additional['image_encoder_arch'] = self.config.head_dim
return additional