mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added working ilora trainer
This commit is contained in:
@@ -19,7 +19,7 @@ from toolkit.saving import load_ip_adapter_model, load_custom_adapter_model
|
||||
from toolkit.train_tools import get_torch_dtype
|
||||
|
||||
sys.path.append(REPOS_ROOT)
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional
|
||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||
AttnProcessor2_0
|
||||
@@ -145,6 +145,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.ilora_module = InstantLoRAModule(
|
||||
vision_tokens=vision_tokens,
|
||||
vision_hidden_size=vision_hidden_size,
|
||||
head_dim=1024,
|
||||
sd=self.sd_ref()
|
||||
)
|
||||
elif self.adapter_type == 'text_encoder':
|
||||
@@ -875,3 +876,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.vision_encoder.enable_gradient_checkpointing()
|
||||
elif hasattr(self.vision_encoder, 'gradient_checkpointing'):
|
||||
self.vision_encoder.gradient_checkpointing = True
|
||||
|
||||
def get_additional_save_metadata(self) -> Dict[str, Any]:
|
||||
if self.config.type == 'ilora':
|
||||
return self.ilora_module.get_additional_save_metadata()
|
||||
return {}
|
||||
Reference in New Issue
Block a user