Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-29 02:31:17 +00:00
Added an experimental clip fusion model that is showing promise for embedding concepts
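Note: toolkit/models/clip_fusion.py itself is not part of this diff, so the internals of CLIPFusionModule are not shown. As orientation only, here is a minimal sketch of a fusion block with the constructor signature the commit uses (text_hidden_size, text_tokens, vision_hidden_size, vision_tokens); the cross-attention layout, head count, and residual wiring are assumptions, not the actual module:

import torch
import torch.nn as nn


class CLIPFusionSketch(nn.Module):
    # Hypothetical stand-in for toolkit.models.clip_fusion.CLIPFusionModule.
    def __init__(self, text_hidden_size, text_tokens, vision_hidden_size, vision_tokens, num_heads=8):
        super().__init__()
        # project vision tokens into the text embedding space
        self.vision_proj = nn.Linear(vision_hidden_size, text_hidden_size)
        # text tokens (queries) attend over the projected vision tokens (keys/values)
        self.cross_attn = nn.MultiheadAttention(text_hidden_size, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(text_hidden_size)

    def forward(self, text_embeds, vision_embeds):
        # text_embeds: (b, text_tokens, text_hidden_size)
        # vision_embeds: (b, vision_tokens, vision_hidden_size)
        v = self.vision_proj(vision_embeds)
        fused, _ = self.cross_attn(text_embeds, v, v)
        # residual keeps the original prompt; attention injects image concepts
        return self.norm(text_embeds + fused)

The calling convention matches how the diff uses the module: it takes the prompt's text embeds plus the vision encoder's token sequence and returns updated text embeds (see condition_encoded_embeds below).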
@@ -5,6 +5,7 @@ from PIL import Image
 from torch.nn import Parameter
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

+from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -72,50 +73,66 @@ class CustomAdapter(torch.nn.Module):
         # add for dataloader
         self.clip_image_processor = self.image_processor

-        if self.adapter_type == 'photo_maker':
-            # try to load from our name_or_path
-            if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
-                self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
-            # add the trigger word to the tokenizer
-            if isinstance(self.sd_ref().tokenizer, list):
-                for tokenizer in self.sd_ref().tokenizer:
-                    tokenizer.add_tokens([self.flag_word], special_tokens=True)
-            else:
-                self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)
+        self.clip_fusion_module: CLIPFusionModule = None
+
+        self.setup_adapter()
+
+        # try to load from our name_or_path
+        if self.config.name_or_path is not None and self.config.name_or_path.endswith('.bin'):
+            self.load_state_dict(torch.load(self.config.name_or_path, map_location=self.device), strict=False)
+
+        # add the trigger word to the tokenizer
+        if isinstance(self.sd_ref().tokenizer, list):
+            for tokenizer in self.sd_ref().tokenizer:
+                tokenizer.add_tokens([self.flag_word], special_tokens=True)
+        else:
+            self.sd_ref().tokenizer.add_tokens([self.flag_word], special_tokens=True)

     def setup_adapter(self):
         if self.adapter_type == 'photo_maker':
             sd = self.sd_ref()
             embed_dim = sd.unet.config['cross_attention_dim']
             self.fuse_module = FuseModule(embed_dim)
+        elif self.adapter_type == 'clip_fusion':
+            sd = self.sd_ref()
+            embed_dim = sd.unet.config['cross_attention_dim']
+
+            vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
+            if self.config.image_encoder_arch == 'clip':
+                vision_tokens = vision_tokens + 1
+            self.clip_fusion_module = CLIPFusionModule(
+                text_hidden_size=embed_dim,
+                text_tokens=77,
+                vision_hidden_size=self.vision_encoder.config.hidden_size,
+                vision_tokens=vision_tokens
+            )
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")

     def forward(self, *args, **kwargs):
-        if self.adapter_type == 'photo_maker':
-            id_pixel_values = args[0]
-            prompt_embeds: PromptEmbeds = args[1]
-            class_tokens_mask = args[2]
-
-            grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
-
-            with torch.set_grad_enabled(grads_on_image_encoder):
-                id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
-
-                if not grads_on_image_encoder:
-                    id_embeds = id_embeds.detach()
-
-            prompt_embeds = prompt_embeds.detach()
-
-            updated_prompt_embeds = self.fuse_module(
-                prompt_embeds, id_embeds, class_tokens_mask
-            )
-
-            return updated_prompt_embeds
-        else:
-            raise NotImplementedError
+        # dont think this is used
+        # if self.adapter_type == 'photo_maker':
+        #     id_pixel_values = args[0]
+        #     prompt_embeds: PromptEmbeds = args[1]
+        #     class_tokens_mask = args[2]
+        #
+        #     grads_on_image_encoder = self.config.train_image_encoder and torch.is_grad_enabled()
+        #
+        #     with torch.set_grad_enabled(grads_on_image_encoder):
+        #         id_embeds = self.vision_encoder(self, id_pixel_values, do_projection2=False)
+        #
+        #         if not grads_on_image_encoder:
+        #             id_embeds = id_embeds.detach()
+        #
+        #     prompt_embeds = prompt_embeds.detach()
+        #
+        #     updated_prompt_embeds = self.fuse_module(
+        #         prompt_embeds, id_embeds, class_tokens_mask
+        #     )
+        #
+        #     return updated_prompt_embeds
+        # else:
+        raise NotImplementedError

     def setup_clip(self):
         adapter_config = self.config
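Side note on the vision_tokens arithmetic in setup_adapter above: it is the standard ViT token count, a (image_size / patch_size)^2 patch grid plus one prepended CLS token when image_encoder_arch is 'clip'. A quick worked example; the 224/14 values assume a typical CLIP ViT-L/14 config and are not pinned by this commit:

image_size = 224   # assumed example value
patch_size = 14    # assumed example value
vision_tokens = (image_size // patch_size) ** 2   # 16 * 16 = 256 patch tokens
vision_tokens = vision_tokens + 1                 # +1 CLS token for 'clip' -> 257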
@@ -226,7 +243,9 @@ class CustomAdapter(torch.nn.Module):
             # self.sd_ref().pipeline.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker")
             # self.sd_ref().pipeline.fuse_lora()
             pass
-        if 'id_encoder' in state_dict and self.adapter_type == 'photo_maker':
+        if 'clip_fusion' in state_dict:
+            self.clip_fusion_module.load_state_dict(state_dict['clip_fusion'], strict=strict)
+        if 'id_encoder' in state_dict and (self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion'):
             self.vision_encoder.load_state_dict(state_dict['id_encoder'], strict=strict)
             # check to see if the fuse weights are there
             fuse_weights = {}
@@ -235,7 +254,31 @@ class CustomAdapter(torch.nn.Module):
                     k = k.replace('fuse_module.', '')
                     fuse_weights[k] = v
             if len(fuse_weights) > 0:
-                self.fuse_module.load_state_dict(fuse_weights, strict=strict)
+                try:
+                    self.fuse_module.load_state_dict(fuse_weights, strict=strict)
+                except Exception as e:
+
+                    print(e)
+                    # force load it
+                    print(f"force loading fuse module as it did not match")
+                    current_state_dict = self.fuse_module.state_dict()
+                    for k, v in fuse_weights.items():
+                        if len(v.shape) == 1:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0]]
+                        elif len(v.shape) == 2:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1]]
+                        elif len(v.shape) == 3:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
+                                                    :current_state_dict[k].shape[2]]
+                        elif len(v.shape) == 4:
+                            current_state_dict[k] = v[:current_state_dict[k].shape[0], :current_state_dict[k].shape[1],
+                                                    :current_state_dict[k].shape[2], :current_state_dict[k].shape[3]]
+                        else:
+                            raise ValueError(f"unknown shape: {v.shape}")
+                    self.fuse_module.load_state_dict(current_state_dict, strict=strict)
+
+        if 'vision_encoder' in state_dict:
+            self.vision_encoder.load_state_dict(state_dict['vision_encoder'], strict=strict)

         if 'fuse_module' in state_dict:
             self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
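The force-load fallback above recovers from shape mismatches by slicing each checkpoint tensor down to the shape the live module expects, with one branch per tensor rank. The same idea can be written rank-agnostically with a tuple of slices; a standalone sketch of the pattern, not toolkit code:

import torch

def truncate_to_fit(src: torch.Tensor, dst_shape: torch.Size) -> torch.Tensor:
    # slice every dimension of src down to dst_shape, e.g. when a checkpoint
    # was saved with more tokens or channels than the current module has
    return src[tuple(slice(0, d) for d in dst_shape)]

Note the fallback only helps when the checkpoint tensor is at least as large as the target in every dimension; slicing cannot grow a smaller tensor, so that case would still fail inside load_state_dict.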
@@ -252,6 +295,12 @@ class CustomAdapter(torch.nn.Module):

             # todo save LoRA
             return state_dict

+        elif self.adapter_type == 'clip_fusion':
+            if self.config.train_image_encoder:
+                state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+            state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
+            return state_dict
+
         else:
             raise NotImplementedError
@@ -260,7 +309,9 @@ class CustomAdapter(torch.nn.Module):
             prompt: Union[List[str], str],
             is_unconditional: bool = False,
     ):
-        if self.adapter_type == 'photo_maker':
+        if self.adapter_type == 'clip_fusion':
+            return prompt
+        elif self.adapter_type == 'photo_maker':
             if is_unconditional:
                 return prompt
             else:
@@ -310,7 +361,8 @@ class CustomAdapter(torch.nn.Module):
                 # add the first one to the front of the prompt
                 tokened_prompt = self.config.class_names[0] + ' ' + self.flag_word + ' ' + prompt
                 our_class = self.config.class_names[0]
-                prompt = " ".join([self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt
+                prompt = " ".join(
+                    [self.config.class_names[0] for _ in range(self.num_control_images)]) + ' ' + prompt

                 # add the prompt to the list
                 new_prompt_list.append(prompt)
@@ -329,8 +381,7 @@ class CustomAdapter(torch.nn.Module):

                 class_token = tokenized_prompt[flag_idx - 1]

-
-                boolean_mask = torch.zeros(flag_idx-1, dtype=torch.bool)
+                boolean_mask = torch.zeros(flag_idx - 1, dtype=torch.bool)
                 boolean_mask = torch.cat((boolean_mask, torch.ones(self.num_control_images, dtype=torch.bool)))
                 boolean_mask = boolean_mask.to(self.device)
                 # zero pad it to 77
@@ -357,62 +408,96 @@ class CustomAdapter(torch.nn.Module):
             has_been_preprocessed=False,
             is_unconditional=False
     ) -> PromptEmbeds:
-        if self.adapter_type == 'photo_maker':
+        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
             if is_unconditional:
                 # we dont condition the negative embeds for photo maker
                 return prompt_embeds
             with torch.no_grad():
                 # on training the clip image is created in the dataloader
                 if not has_been_preprocessed:
                     # tensors should be 0-1
                     if tensors_0_1.ndim == 3:
                         tensors_0_1 = tensors_0_1.unsqueeze(0)
                     # training tensors are 0 - 1
                     tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
                     # if images are out of this range throw error
                     if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                         raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                             tensors_0_1.min(), tensors_0_1.max()
                         ))
                     clip_image = self.image_processor(
                         images=tensors_0_1,
                         return_tensors="pt",
                         do_resize=True,
                         do_rescale=False,
                     ).pixel_values
                 else:
                     clip_image = tensors_0_1
                 clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()

-            # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
-            clip_image = clip_image.unsqueeze(1)
-
-            with torch.set_grad_enabled(is_training):
-                if is_training and self.config.train_image_encoder:
-                    self.vision_encoder.train()
-                    clip_image = clip_image.requires_grad_(True)
-                    id_embeds = self.vision_encoder(
-                        clip_image,
-                        do_projection2=isinstance(self.sd_ref().text_encoder, list),
-                    )
-                else:
-                    with torch.no_grad():
-                        self.vision_encoder.eval()
-                        id_embeds = self.vision_encoder(
-                            clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
-                        ).detach()
-
-            prompt_embeds.text_embeds = self.fuse_module(
-                prompt_embeds.text_embeds,
-                id_embeds,
-                self.token_mask
-            )
-
-            return prompt_embeds
+            if self.adapter_type == 'photo_maker':
+                # Embeddings need to be (b, num_inputs, c, h, w) for now, just put 1 input image
+                clip_image = clip_image.unsqueeze(1)
+                with torch.set_grad_enabled(is_training):
+                    if is_training and self.config.train_image_encoder:
+                        self.vision_encoder.train()
+                        clip_image = clip_image.requires_grad_(True)
+                        id_embeds = self.vision_encoder(
+                            clip_image,
+                            do_projection2=isinstance(self.sd_ref().text_encoder, list),
+                        )
+                    else:
+                        with torch.no_grad():
+                            self.vision_encoder.eval()
+                            id_embeds = self.vision_encoder(
+                                clip_image, do_projection2=isinstance(self.sd_ref().text_encoder, list)
+                            ).detach()
+
+                prompt_embeds.text_embeds = self.fuse_module(
+                    prompt_embeds.text_embeds,
+                    id_embeds,
+                    self.token_mask
+                )
+                return prompt_embeds
+            elif self.adapter_type == 'clip_fusion':
+                with torch.set_grad_enabled(is_training):
+                    if is_training and self.config.train_image_encoder:
+                        self.vision_encoder.train()
+                        clip_image = clip_image.requires_grad_(True)
+                        id_embeds = self.vision_encoder(
+                            clip_image,
+                            output_hidden_states=True,
+                        )
+                    else:
+                        with torch.no_grad():
+                            self.vision_encoder.eval()
+                            id_embeds = self.vision_encoder(
+                                clip_image, output_hidden_states=True
+                            )
+
+                img_embeds = id_embeds['last_hidden_state']
+
+                if not is_training or not self.config.train_image_encoder:
+                    img_embeds = img_embeds.detach()
+
+                prompt_embeds.text_embeds = self.clip_fusion_module(
+                    prompt_embeds.text_embeds,
+                    img_embeds
+                )
+                return prompt_embeds
         else:
             raise NotImplementedError

     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.type == 'photo_maker':
             yield from self.fuse_module.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
+        elif self.config.type == 'clip_fusion':
+            yield from self.clip_fusion_module.parameters(recurse)
+            if self.config.train_image_encoder:
+                yield from self.vision_encoder.parameters(recurse)
         else:
             raise NotImplementedError
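One detail of condition_encoded_embeds worth calling out: gradients reach the vision encoder only when both is_training and config.train_image_encoder are set; in every other case the image embeddings are detached, so only the fusion module receives updates while the encoder acts as a frozen feature extractor. The same pattern in isolation, as a minimal sketch rather than toolkit code:

import torch

def encode_image(encoder, pixels, is_training, train_encoder):
    # gradients flow into the encoder only while it is being trained;
    # otherwise detach so it behaves as a fixed feature extractor
    with torch.set_grad_enabled(is_training and train_encoder):
        feats = encoder(pixels)
    if not (is_training and train_encoder):
        feats = feats.detach()
    return feats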