Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-21 14:59:02 +00:00
Added single value adapter training
@@ -9,6 +9,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5En
 from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.models.ilora import InstantLoRAModule
+from toolkit.models.single_value_adapter import SingleValueAdapter
 from toolkit.models.te_adapter import TEAdapter
 from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
@@ -87,6 +88,7 @@ class CustomAdapter(torch.nn.Module):
         self.te_adapter: TEAdapter = None
         self.te_augmenter: TEAugAdapter = None
         self.vd_adapter: VisionDirectAdapter = None
+        self.single_value_adapter: SingleValueAdapter = None
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
 
@@ -173,6 +175,8 @@ class CustomAdapter(torch.nn.Module):
             self.te_augmenter = TEAugAdapter(self, self.sd_ref())
         elif self.adapter_type == 'vision_direct':
             self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
+        elif self.adapter_type == 'single_value':
+            self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
 
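
Note: this hunk only shows the call site; toolkit/models/single_value_adapter.py itself is not part of the excerpt. A minimal sketch of the interface the constructor call implies (everything beyond the argument names and num_values is an assumption):

    import torch

    class SingleValueAdapter(torch.nn.Module):
        # Hypothetical sketch; the real module is not shown in this diff.
        # Signature inferred from the call site above:
        #   SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
        def __init__(self, adapter, sd, num_values: int = 1):
            super().__init__()
            self.num_values = num_values
            # one trainable weight per conditioned scalar value (assumption)
            self.weight = torch.nn.Parameter(torch.zeros(num_values))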
@@ -204,7 +208,7 @@ class CustomAdapter(torch.nn.Module):
     def setup_clip(self):
         adapter_config = self.config
         sd = self.sd_ref()
-        if self.config.type == "text_encoder":
+        if self.config.type == "text_encoder" or self.config.type == "single_value":
             return
         if self.config.type == 'photo_maker':
             try:
@@ -374,6 +378,9 @@ class CustomAdapter(torch.nn.Module):
         if 'dvadapter' in state_dict:
             self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
 
+        if 'sv_adapter' in state_dict:
+            self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
+
         if 'vision_encoder' in state_dict and self.config.train_image_encoder:
             self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
 
@@ -417,6 +424,9 @@ class CustomAdapter(torch.nn.Module):
             if self.config.train_image_encoder:
                 state_dict["vision_encoder"] = self.vision_encoder.state_dict()
             return state_dict
+        elif self.adapter_type == 'single_value':
+            state_dict["sv_adapter"] = self.single_value_adapter.state_dict()
+            return state_dict
         elif self.adapter_type == 'ilora':
             if self.config.train_image_encoder:
                 state_dict["vision_encoder"] = self.vision_encoder.state_dict()
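
Together with the load_state_dict hunk above, the new "sv_adapter" key round-trips through the usual torch calls; a hedged usage sketch, where adapter stands in for an instantiated CustomAdapter and the filename is illustrative:

    import torch

    # save: state_dict() nests the single value adapter under "sv_adapter"
    torch.save(adapter.state_dict(), "sv_adapter.pt")

    # load: the patched load_state_dict() picks the "sv_adapter" key back up
    adapter.load_state_dict(torch.load("sv_adapter.pt"), strict=True)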
@@ -425,6 +435,14 @@ class CustomAdapter(torch.nn.Module):
         else:
             raise NotImplementedError
 
+    def add_extra_values(self, extra_values: torch.Tensor, is_unconditional=False):
+        if self.adapter_type == 'single_value':
+            if is_unconditional:
+                self.unconditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
+            else:
+                self.conditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
+
+
     def condition_prompt(
         self,
         prompt: Union[List[str], str],
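
A hedged sketch of how a caller might feed values through the new method (adapter and the tensor shape are stand-ins; the real training loop lives elsewhere in ai-toolkit):

    import torch

    # one scalar per value slot, shaped (batch, num_values) -- an assumption
    cond_values = torch.tensor([[0.75]])

    # add_extra_values() moves each tensor to the model's device and dtype
    # and stores it as the conditional or unconditional embeds, as above
    adapter.add_extra_values(cond_values, is_unconditional=False)
    adapter.add_extra_values(torch.zeros_like(cond_values), is_unconditional=True)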
@@ -843,6 +861,8 @@ class CustomAdapter(torch.nn.Module):
             yield from self.te_augmenter.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
+        elif self.config.type == 'single_value':
+            yield from self.single_value_adapter.parameters(recurse)
         else:
             raise NotImplementedError
 
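Since parameters() now yields only the single value adapter's weights when config.type == 'single_value', an optimizer built the usual way trains just those; a sketch, again assuming an instantiated CustomAdapter named adapter:

    import torch

    # for this adapter type, adapter.parameters() yields from
    # self.single_value_adapter.parameters(recurse) per the hunk above
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=1e-4)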