Added te aug adapter

Jaret Burkett
2024-02-21 21:30:26 -07:00
parent 49c41e6a5f
commit b68c3ef734
5 changed files with 310 additions and 8 deletions

@@ -10,6 +10,7 @@ from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.models.ilora import InstantLoRAModule
 from toolkit.models.te_adapter import TEAdapter
+from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -60,6 +61,7 @@ class CustomAdapter(torch.nn.Module):
         self.current_scale = 1.0
         self.is_active = True
         self.flag_word = "fla9wor0"
+        self.is_unconditional_run = False
         self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None
@@ -83,6 +85,7 @@ class CustomAdapter(torch.nn.Module):
         self.te: Union[T5EncoderModel, CLIPTextModel] = None
         self.tokenizer: CLIPTokenizer = None
         self.te_adapter: TEAdapter = None
+        self.te_augmenter: TEAugAdapter = None
         self.vd_adapter: VisionDirectAdapter = None
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -149,6 +152,8 @@ class CustomAdapter(torch.nn.Module):
raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
elif self.adapter_type == 'te_augmenter':
self.te_augmenter = TEAugAdapter(self, self.sd_ref())
elif self.adapter_type == 'vision_direct':
self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
else:
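Note: TEAugAdapter itself is defined in toolkit/models/te_aug_adapter.py, one of the other changed files, and is not shown in this hunk. Going only by how it is used here (constructed with the adapter and the dereferenced sd_ref, later called on CLIP image embeddings, and exposing state_dict()/parameters()), a minimal sketch might look like the following; the class body, the embed_dim value, and the single linear projection are illustrative assumptions, not the actual implementation.

import weakref

import torch
import torch.nn as nn


class TEAugAdapterSketch(nn.Module):
    # hypothetical stand-in for toolkit.models.te_aug_adapter.TEAugAdapter
    def __init__(self, adapter, sd, embed_dim=1024):  # embed_dim is an assumed example value
        super().__init__()
        # weak references keep the parent adapter / sd wrapper from being registered as submodules
        self.adapter_ref = weakref.ref(adapter)
        self.sd_ref = weakref.ref(sd)
        self.proj = nn.Linear(embed_dim, embed_dim)  # purely illustrative projection

    def forward(self, clip_image_embeds: torch.Tensor) -> torch.Tensor:
        # same shape in and out, so the chunk(2, dim=0) split later in this file still works
        return self.proj(clip_image_embeds)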
@@ -269,9 +274,13 @@ class CustomAdapter(torch.nn.Module):
             preprocessor_input_size = self.vision_encoder.config.image_size * 2
             # update the preprocessor so images come in at the right size
-            self.image_processor.size['shortest_edge'] = preprocessor_input_size
-            self.image_processor.crop_size['height'] = preprocessor_input_size
-            self.image_processor.crop_size['width'] = preprocessor_input_size
+            if 'height' in self.image_processor.size:
+                self.image_processor.size['height'] = preprocessor_input_size
+                self.image_processor.size['width'] = preprocessor_input_size
+            elif hasattr(self.image_processor, 'crop_size'):
+                self.image_processor.size['shortest_edge'] = preprocessor_input_size
+                self.image_processor.crop_size['height'] = preprocessor_input_size
+                self.image_processor.crop_size['width'] = preprocessor_input_size
 
         if self.config.image_encoder_arch == 'clip+':
             # self.image_processor.config
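Note: the reworked preprocessor update is needed because Hugging Face image processors do not all describe their target resolution the same way: CLIP-style processors use size={'shortest_edge': ...} plus a separate crop_size, while others (SigLIP-style) use size={'height': ..., 'width': ...} and have no crop step. A quick check with a recent transformers release (illustrative, not from this repo) shows the two layouts; both branches then push the doubled size, presumably so a 2x2 quad-image grid still matches the encoder's native resolution, into whichever fields the processor actually reads.

from transformers import CLIPImageProcessor, SiglipImageProcessor

clip_proc = CLIPImageProcessor()
siglip_proc = SiglipImageProcessor()

print(clip_proc.size)                      # {'shortest_edge': 224}
print(clip_proc.crop_size)                 # {'height': 224, 'width': 224}
print(siglip_proc.size)                    # {'height': 224, 'width': 224}
print(hasattr(siglip_proc, 'crop_size'))   # False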
@@ -340,6 +349,9 @@ class CustomAdapter(torch.nn.Module):
         if 'te_adapter' in state_dict:
             self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
+        if 'te_augmenter' in state_dict:
+            self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
         if 'vd_adapter' in state_dict:
             self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
         if 'dvadapter' in state_dict:
@@ -378,6 +390,11 @@ class CustomAdapter(torch.nn.Module):
         elif self.adapter_type == 'text_encoder':
             state_dict["te_adapter"] = self.te_adapter.state_dict()
             return state_dict
+        elif self.adapter_type == 'te_augmenter':
+            if self.config.train_image_encoder:
+                state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+            state_dict["te_augmenter"] = self.te_augmenter.state_dict()
+            return state_dict
         elif self.adapter_type == 'vision_direct':
             state_dict["dvadapter"] = self.vd_adapter.state_dict()
             if self.config.train_image_encoder:
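Note: saving mirrors loading: the loading hunk above looks up the same 'te_augmenter' (and optional 'vision_encoder') keys that this branch writes. A self-contained sketch of the resulting checkpoint layout, with toy modules standing in for the real TEAugAdapter and vision encoder and a hypothetical output path:

import torch
import torch.nn as nn

train_image_encoder = False        # mirrors config.train_image_encoder

te_augmenter = nn.Linear(8, 8)     # stand-in for the real TEAugAdapter
vision_encoder = nn.Linear(8, 8)   # stand-in for the CLIP vision encoder

state_dict = {"te_augmenter": te_augmenter.state_dict()}
if train_image_encoder:
    state_dict["vision_encoder"] = vision_encoder.state_dict()

torch.save(state_dict, "te_aug_adapter.pt")  # hypothetical path; the repo's own saver may differ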
@@ -647,7 +664,7 @@ class CustomAdapter(torch.nn.Module):
             has_been_preprocessed=False,
             quad_count=4,
     ) -> PromptEmbeds:
-        if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct':
+        if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
             if tensors_0_1 is None:
                 tensors_0_1 = self.get_empty_clip_image(1)
                 has_been_preprocessed = True
@@ -675,7 +692,7 @@ class CustomAdapter(torch.nn.Module):
                 clip_image = tensors_0_1
             batch_size = clip_image.shape[0]
-            if self.adapter_type == 'vision_direct':
+            if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
                 # add an unconditional so we can save it
                 unconditional = self.get_empty_clip_image(batch_size).to(
                     clip_image.device, dtype=clip_image.dtype
                 )
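Note: for 'vision_direct', and now 'te_augmenter' as well, an empty reference image batch is generated so a matching unconditional embedding can come out of the same encoder pass; the chunk(2, dim=0) further down splits the two halves apart again. A stand-in sketch of that batching pattern (the zeros tensor and the toy encoder are illustrative, not the repo's get_empty_clip_image or vision encoder, and the unconditional-first ordering simply matches the unpacking order in the later hunk):

import torch

batch_size = 2
clip_image = torch.rand(batch_size, 3, 224, 224)      # conditional reference images
unconditional = torch.zeros_like(clip_image)          # stand-in for get_empty_clip_image(batch_size)

both = torch.cat([unconditional, clip_image], dim=0)  # encode both halves in a single pass
embeds = both.flatten(1).mean(dim=1, keepdim=True)    # toy "vision encoder", shape (2 * batch_size, 1)

uncond_embeds, cond_embeds = embeds.chunk(2, dim=0)   # the split the later hunk performs
assert uncond_embeds.shape[0] == cond_embeds.shape[0] == batch_size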
@@ -730,7 +747,7 @@ class CustomAdapter(torch.nn.Module):
                     img_embeds = img_embeds.detach()
                 self.ilora_module(img_embeds)
-            if self.adapter_type == 'vision_direct':
+            if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
                 with torch.set_grad_enabled(is_training):
                     if is_training and self.config.train_image_encoder:
                         self.vision_encoder.train()
@@ -754,8 +771,14 @@ class CustomAdapter(torch.nn.Module):
                     if not is_training or not self.config.train_image_encoder:
                         clip_image_embeds = clip_image_embeds.detach()
+                    if self.adapter_type == 'te_augmenter':
+                        clip_image_embeds = self.te_augmenter(clip_image_embeds)
                 # save them to the conditional and unconditional
-                self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+                try:
+                    self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+                except ValueError:
+                    raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")
 
     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.train_only_image_encoder:
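Note: the new try/except is useful because Tensor.chunk(2, dim=0) returns fewer than two pieces when the leading dimension is 1; it is the tuple unpacking that then raises ValueError, and the wrapper re-raises it with the offending shape attached. A small demonstration of how that failure arises:

import torch

embeds = torch.rand(1, 77, 768)      # only one entry in the batch dimension
print(len(embeds.chunk(2, dim=0)))   # 1 -- chunk() does not pad up to two pieces

try:
    uncond, cond = embeds.chunk(2, dim=0)
except ValueError as err:
    print("unpacking failed:", err)  # the failure the new message decorates with the shape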
@@ -781,6 +804,10 @@ class CustomAdapter(torch.nn.Module):
                 yield from attn_processor.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
+        elif self.config.type == 'te_augmenter':
+            yield from self.te_augmenter.parameters(recurse)
+            if self.config.train_image_encoder:
+                yield from self.vision_encoder.parameters(recurse)
         else:
             raise NotImplementedError
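Note: this parameters() override decides which weights reach the optimizer; for the new type it yields the augmenter's weights and, only when train_image_encoder is set, the vision encoder's as well. A self-contained sketch of that selection being consumed (the toy modules and learning rate are placeholders, not the repo's trainer):

import torch
import torch.nn as nn

augmenter = nn.Linear(1024, 1024)       # stand-in for the TEAugAdapter weights
vision_encoder = nn.Linear(1024, 1024)  # stand-in for the CLIP vision encoder
train_image_encoder = False             # mirrors config.train_image_encoder

def trainable_params():
    yield from augmenter.parameters()
    if train_image_encoder:
        yield from vision_encoder.parameters()

optimizer = torch.optim.AdamW(trainable_params(), lr=1e-4)  # example learning rate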