Added an experimental clip fusion model that is showing promise for embedding concepts

This commit is contained in:
Jaret Burkett
2024-01-17 13:13:04 -07:00
parent 655533d4c7
commit 86c70a2a1f
3 changed files with 320 additions and 89 deletions

View File

@@ -5,6 +5,7 @@ from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -72,50 +73,66 @@ class CustomAdapter(torch.nn.Module):
# add for dataloader
self.clip_image_processor = self.image_processor
self.clip_fusion_module: CLIPFusionModule = None
self.setup_adapter()
# try to load from our name_or_path
if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
# add the trigger word to the tokenizer
if isinstance(self.sd_ref().tokenizer, list):
for tokenizer in self.sd_ref().tokenizer:
tokenizer.add_tokens([self.flag_word], special_tokens=True)
else:
self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
if self.adapter_type == 'photo_maker':
# try to load from our name_or_path
if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
# add the trigger word to the tokenizer
if isinstance(self.sd_ref().tokenizer, list):
for tokenizer in self.sd_ref().tokenizer:
tokenizer.add_tokens([self.flag_word], special_tokens=True)
else:
self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
def setup_adapter(self):
if self.adapter_type == 'photo_maker':
sd = self.sd_ref()
embed_dim = sd.unet.config['cross_attention_dim']
self.fuse_module = FuseModule(embed_dim)
elif self.adapter_type == 'clip_fusion':
sd = self.sd_ref()
embed_dim = sd.unet.config['cross_attention_dim']
vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
if self.config.image_encoder_arch == 'clip':
vision_tokens = vision_tokens + 1
self.clip_fusion_module = CLIPFusionModule(
text_hidden_size=embed_dim,
text_tokens=77,
vision_hidden_size=self.vision_encoder.config.hidden_size,
vision_tokens=vision_tokens
)
else:
raise ValueError(f"unknown adapter type: {self.adapter_type}")
def forward(self, *args, **kwargs):
if self.adapter_type == 'photo_maker':
id_pixel_values = args[0]
prompt_embeds: PromptEmbeds = args[1]
class_tokens_mask = args[2]
grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
with torch.set_grad_enabled(grads_on_image_encoder):
id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
if not grads_on_image_encoder:
id_embeds = id_embeds.detach()
prompt_embeds = prompt_embeds.detach()
updated_prompt_embeds = self.fuse_module(
prompt_embeds, id_embeds, class_tokens_mask
)
return updated_prompt_embeds
else:
raise NotImplementedError
# dont think this is used
# if self.adapter_type == 'photo_maker':
# id_pixel_values = args[0]
# prompt_embeds: PromptEmbeds = args[1]
# class_tokens_mask = args[2]
#
# grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
#
# with torch.set_grad_enabled(grads_on_image_encoder):
# id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
#
# if not grads_on_image_encoder:
# id_embeds = id_embeds.detach()
#
# prompt_embeds = prompt_embeds.detach()
#
# updated_prompt_embeds = self.fuse_module(
# prompt_embeds, id_embeds, class_tokens_mask
# )
#
# return updated_prompt_embeds
# else:
raise NotImplementedError
def setup_clip(self):
adapter_config = self.config
@@ -226,7 +243,9 @@ class CustomAdapter(torch.nn.Module):
# self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
# self.sd_ref().pipeline.fuse_lora()
pass
if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker':
if 'clip_fusion' in state_dict:
self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict)
if 'id_encoder' in state_dict and (self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion'):
self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
# check to see if the fuse weights are there
fuse_weights = {}
@@ -235,7 +254,31 @@ class CustomAdapter(torch.nn.Module):
k = k.replace('fuse_module.', '')
fuse_weights[k] = v
if len(fuse_weights) > 0:
self.fuse_module.load_state_dict(fuse_weights, strict=strict)
try:
self.fuse_module.load_state_dict(fuse_weights, strict=strict)
except Exception as e:
print(e)
# force load it
print(f"force loading fuse module as it did not match")
current_state_dict = self.fuse_module.state_dict()
for k, v in fuse_weights.items():
if len(v.shape) == 1:
current_state_dict[k] = v[:current_state_dict[k].shape[0]]
elif len(v.shape) == 2:
current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1]]
elif len(v.shape) == 3:
current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
:current_state_dict[k].shape[2]]
elif len(v.shape) == 4:
current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
:current_state_dict[k].shape[2], :current_state_dict[k].shape[3]]
else:
raise ValueError(f"unknown shape: {v.shape}")
self.fuse_module.load_state_dict(current_state_dict, strict=strict)
if 'vision_encoder' in state_dict:
self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)
if 'fuse_module' in state_dict:
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
@@ -252,6 +295,12 @@ class CustomAdapter(torch.nn.Module):
# todo save LoRA
return state_dict
elif self.adapter_type == 'clip_fusion':
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
return state_dict
else:
raise NotImplementedError
@@ -260,7 +309,9 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type == 'photo_maker':
if self.adapter_type == 'clip_fusion':
return prompt
elif self.adapter_type == 'photo_maker':
if is_unconditional:
return prompt
else:
@@ -310,7 +361,8 @@ class CustomAdapter(torch.nn.Module):
# add the first one to the front of the prompt
tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
our_class = self.config.class_names[0]
prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
prompt = " ".join(
[self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
# add the prompt to the list
new_prompt_list.append(prompt)
@@ -329,8 +381,7 @@ class CustomAdapter(torch.nn.Module):
class_token = tokenized_prompt[flag_idx - 1]
boolean_mask = torch.zeros(flag_idx-1, dtype=torch.bool)
boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool)
boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
boolean_mask = boolean_mask.to(self.device)
# zero pad it to 77
@@ -357,62 +408,96 @@ class CustomAdapter(torch.nn.Module):
has_been_preprocessed=False,
is_unconditional=False
) -> PromptEmbeds:
if self.adapter_type == 'photo_maker':
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
if is_unconditional:
# we dont condition the negative embeds for photo maker
return prompt_embeds
with torch.no_grad():
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
# tensors should be 0-1
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
with torch.no_grad():
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
# tensors should be 0-1
if tensors_0_1.ndim == 3:
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.image_processor(
images=tensors_0_1,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if self.adapter_type == 'photo_maker':
# Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
clip_image = clip_image.unsqueeze(1)
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
clip_image = clip_image.requires_grad_(True)
id_embeds = self.vision_encoder(
clip_image,
do_projection2=isinstance(self.sd_ref().text_encoder, list),
)
else:
with torch.no_grad():
self.vision_encoder.eval()
id_embeds = self.vision_encoder(
clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
).detach()
prompt_embeds.text_embeds = self.fuse_module(
prompt_embeds.text_embeds,
id_embeds,
self.token_mask
)
return prompt_embeds
elif self.adapter_type == 'clip_fusion':
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
clip_image = clip_image.requires_grad_(True)
id_embeds = self.vision_encoder(
clip_image,
output_hidden_states=True,
)
else:
with torch.no_grad():
self.vision_encoder.eval()
id_embeds = self.vision_encoder(
clip_image, output_hidden_states=True
)
img_embeds = id_embeds['last_hidden_state']
if not is_training or not self.config.train_image_encoder:
img_embeds = img_embeds.detach()
prompt_embeds.text_embeds = self.clip_fusion_module(
prompt_embeds.text_embeds,
img_embeds
)
return prompt_embeds
# Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
clip_image = clip_image.unsqueeze(1)
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
clip_image = clip_image.requires_grad_(True)
id_embeds = self.vision_encoder(
clip_image,
do_projection2=isinstance(self.sd_ref().text_encoder, list),
)
else:
with torch.no_grad():
self.vision_encoder.eval()
id_embeds = self.vision_encoder(
clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
).detach()
prompt_embeds.text_embeds = self.fuse_module(
prompt_embeds.text_embeds,
id_embeds,
self.token_mask
)
return prompt_embeds
else:
raise NotImplementedError
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.type == 'photo_maker':
yield from self.fuse_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
elif self.config.type == 'clip_fusion':
yield from self.clip_fusion_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
else:
raise NotImplementedError