Added single value adapter training

This commit is contained in:
Jaret Burkett
2024-04-28 06:04:47 -06:00
parent b96913d73c
commit 10e1ecf1e8
8 changed files with 462 additions and 7 deletions

View File

@@ -9,6 +9,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, T5En
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.ilora import InstantLoRAModule
from toolkit.models.single_value_adapter import SingleValueAdapter
from toolkit.models.te_adapter import TEAdapter
from toolkit.models.te_aug_adapter import TEAugAdapter
from toolkit.models.vd_adapter import VisionDirectAdapter
@@ -87,6 +88,7 @@ class CustomAdapter(torch.nn.Module):
self.te_adapter: TEAdapter = None
self.te_augmenter: TEAugAdapter = None
self.vd_adapter: VisionDirectAdapter = None
self.single_value_adapter: SingleValueAdapter = None
self.conditional_embeds: Optional[torch.Tensor] = None
self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -173,6 +175,8 @@ class CustomAdapter(torch.nn.Module):
self.te_augmenter = TEAugAdapter(self, self.sd_ref())
elif self.adapter_type == 'vision_direct':
self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
elif self.adapter_type == 'single_value':
self.single_value_adapter = SingleValueAdapter(self, self.sd_ref(), num_values=self.config.num_tokens)
else:
raise ValueError(f"unknown adapter type: {self.adapter_type}")
@@ -204,7 +208,7 @@ class CustomAdapter(torch.nn.Module):
def setup_clip(self):
adapter_config = self.config
sd = self.sd_ref()
if self.config.type == "text_encoder":
if self.config.type == "text_encoder" or self.config.type == "single_value":
return
if self.config.type == 'photo_maker':
try:
@@ -374,6 +378,9 @@ class CustomAdapter(torch.nn.Module):
if 'dvadapter' in state_dict:
self.vd_adapter.load_state_dict(state_dict['dvadapter'], strict=strict)
if 'sv_adapter' in state_dict:
self.single_value_adapter.load_state_dict(state_dict['sv_adapter'], strict=strict)
if 'vision_encoder' in state_dict and self.config.train_image_encoder:
self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
@@ -417,6 +424,9 @@ class CustomAdapter(torch.nn.Module):
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
return state_dict
elif self.adapter_type == 'single_value':
state_dict["sv_adapter"] = self.single_value_adapter.state_dict()
return state_dict
elif self.adapter_type == 'ilora':
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
@@ -425,6 +435,14 @@ class CustomAdapter(torch.nn.Module):
else:
raise NotImplementedError
def add_extra_values(self, extra_values: torch.Tensor, is_unconditional=False):
if self.adapter_type == 'single_value':
if is_unconditional:
self.unconditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
else:
self.conditional_embeds = extra_values.to(self.device, get_torch_dtype(self.sd_ref().dtype))
def condition_prompt(
self,
prompt: Union[List[str], str],
@@ -843,6 +861,8 @@ class CustomAdapter(torch.nn.Module):
yield from self.te_augmenter.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
elif self.config.type == 'single_value':
yield from self.single_value_adapter.parameters(recurse)
else:
raise NotImplementedError