mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-29 18:51:37 +00:00
Added te aug adapter
This commit is contained in:
@@ -10,6 +10,7 @@ from toolkit.models.clip_fusion import CLIPFusionModule
|
||||
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
|
||||
from toolkit.models.ilora import InstantLoRAModule
|
||||
from toolkit.models.te_adapter import TEAdapter
|
||||
from toolkit.models.te_aug_adapter import TEAugAdapter
|
||||
from toolkit.models.vd_adapter import VisionDirectAdapter
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
|
||||
@@ -60,6 +61,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.current_scale = 1.0
|
||||
self.is_active = True
|
||||
self.flag_word = "fla9wor0"
|
||||
self.is_unconditional_run = False
|
||||
|
||||
self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None
|
||||
|
||||
@@ -83,6 +85,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
self.te: Union[T5EncoderModel, CLIPTextModel] = None
|
||||
self.tokenizer: CLIPTokenizer = None
|
||||
self.te_adapter: TEAdapter = None
|
||||
self.te_augmenter: TEAugAdapter = None
|
||||
self.vd_adapter: VisionDirectAdapter = None
|
||||
self.conditional_embeds: Optional[torch.Tensor] = None
|
||||
self.unconditional_embeds: Optional[torch.Tensor] = None
|
||||
@@ -149,6 +152,8 @@ class CustomAdapter(torch.nn.Module):
|
||||
raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
|
||||
|
||||
self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
|
||||
elif self.adapter_type == 'te_augmenter':
|
||||
self.te_augmenter = TEAugAdapter(self, self.sd_ref())
|
||||
elif self.adapter_type == 'vision_direct':
|
||||
self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
|
||||
else:
|
||||
@@ -269,9 +274,13 @@ class CustomAdapter(torch.nn.Module):
|
||||
preprocessor_input_size = self.vision_encoder.config.image_size * 2
|
||||
|
||||
# update the preprocessor so images come in at the right size
|
||||
self.image_processor.size['shortest_edge'] = preprocessor_input_size
|
||||
self.image_processor.crop_size['height'] = preprocessor_input_size
|
||||
self.image_processor.crop_size['width'] = preprocessor_input_size
|
||||
if 'height' in self.image_processor.size:
|
||||
self.image_processor.size['height'] = preprocessor_input_size
|
||||
self.image_processor.size['width'] = preprocessor_input_size
|
||||
elif hasattr(self.image_processor, 'crop_size'):
|
||||
self.image_processor.size['shortest_edge'] = preprocessor_input_size
|
||||
self.image_processor.crop_size['height'] = preprocessor_input_size
|
||||
self.image_processor.crop_size['width'] = preprocessor_input_size
|
||||
|
||||
if self.config.image_encoder_arch == 'clip+':
|
||||
# self.image_processor.config
|
||||
@@ -340,6 +349,9 @@ class CustomAdapter(torch.nn.Module):
|
||||
if 'te_adapter' in state_dict:
|
||||
self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
|
||||
|
||||
if 'te_augmenter' in state_dict:
|
||||
self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
|
||||
|
||||
if 'vd_adapter' in state_dict:
|
||||
self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
|
||||
if 'dvadapter' in state_dict:
|
||||
@@ -378,6 +390,11 @@ class CustomAdapter(torch.nn.Module):
|
||||
elif self.adapter_type == 'text_encoder':
|
||||
state_dict["te_adapter"] = self.te_adapter.state_dict()
|
||||
return state_dict
|
||||
elif self.adapter_type == 'te_augmenter':
|
||||
if self.config.train_image_encoder:
|
||||
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
|
||||
state_dict["te_augmenter"] = self.te_augmenter.state_dict()
|
||||
return state_dict
|
||||
elif self.adapter_type == 'vision_direct':
|
||||
state_dict["dvadapter"] = self.vd_adapter.state_dict()
|
||||
if self.config.train_image_encoder:
|
||||
@@ -647,7 +664,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
has_been_preprocessed=False,
|
||||
quad_count=4,
|
||||
) -> PromptEmbeds:
|
||||
if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct':
|
||||
if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
||||
if tensors_0_1 is None:
|
||||
tensors_0_1 = self.get_empty_clip_image(1)
|
||||
has_been_preprocessed = True
|
||||
@@ -675,7 +692,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
clip_image = tensors_0_1
|
||||
|
||||
batch_size = clip_image.shape[0]
|
||||
if self.adapter_type == 'vision_direct':
|
||||
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
||||
# add an unconditional so we can save it
|
||||
unconditional = self.get_empty_clip_image(batch_size).to(
|
||||
clip_image.device, dtype=clip_image.dtype
|
||||
@@ -730,7 +747,7 @@ class CustomAdapter(torch.nn.Module):
|
||||
img_embeds = img_embeds.detach()
|
||||
|
||||
self.ilora_module(img_embeds)
|
||||
if self.adapter_type == 'vision_direct':
|
||||
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if is_training and self.config.train_image_encoder:
|
||||
self.vision_encoder.train()
|
||||
@@ -754,8 +771,14 @@ class CustomAdapter(torch.nn.Module):
|
||||
if not is_training or not self.config.train_image_encoder:
|
||||
clip_image_embeds = clip_image_embeds.detach()
|
||||
|
||||
if self.adapter_type == 'te_augmenter':
|
||||
clip_image_embeds = self.te_augmenter(clip_image_embeds)
|
||||
|
||||
# save them to the conditional and unconditional
|
||||
self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
|
||||
try:
|
||||
self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
|
||||
except ValueError:
|
||||
raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")
|
||||
|
||||
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
|
||||
if self.config.train_only_image_encoder:
|
||||
@@ -781,6 +804,10 @@ class CustomAdapter(torch.nn.Module):
|
||||
yield from attn_processor.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
elif self.config.type == 'te_augmenter':
|
||||
yield from self.te_augmenter.parameters(recurse)
|
||||
if self.config.train_image_encoder:
|
||||
yield from self.vision_encoder.parameters(recurse)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
Reference in New Issue
Block a user