Added working ilora trainer

This commit is contained in:
Jaret Burkett
2024-06-12 09:33:45 -06:00
parent 3f3636b788
commit cb5d28cba9
6 changed files with 261 additions and 196 deletions

View File

@@ -19,7 +19,7 @@ from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
@@ -145,6 +145,7 @@ class CustomAdapter(torch.nn.Module):
self.ilora_module = InstantLoRAModule(
vision_tokens=vision_tokens,
vision_hidden_size=vision_hidden_size,
head_dim=1024,
sd=self.sd_ref()
)
elif self.adapter_type == 'text_encoder':
@@ -875,3 +876,8 @@ class CustomAdapter(torch.nn.Module):
self.vision_encoder.enable_gradient_checkpointing()
elif hasattr(self.vision_encoder, 'gradient_checkpointing'):
self.vision_encoder.gradient_checkpointing = True
def get_additional_save_metadata(self) -> Dict[str, Any]:
    """Return adapter-specific metadata to store alongside a saved model.

    For adapters configured as ``'ilora'`` this delegates to the
    ``InstantLoRAModule``'s own ``get_additional_save_metadata``; every
    other adapter type contributes no extra metadata.

    Returns:
        Dict[str, Any]: the iLoRA module's metadata dict, or ``{}``.
    """
    # NOTE(review): assumes self.ilora_module is always constructed when
    # config.type == 'ilora' (as the __init__ hunk above suggests) —
    # otherwise this would raise AttributeError; confirm in the constructor.
    if self.config.type == 'ilora':
        return self.ilora_module.get_additional_save_metadata()
    return {}