From b68c3ef734162ba6b1ec33d8c8ad0853a2d55774 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Wed, 21 Feb 2024 21:30:26 -0700
Subject: [PATCH] Added te aug adapter

---
 extensions_built_in/sd_trainer/SDTrainer.py |  15 ++
 jobs/process/TrainVAEProcess.py             |   3 +-
 toolkit/custom_adapter.py                   |  41 +++-
 toolkit/models/te_aug_adapter.py            | 253 ++++++++++++++++++++
 toolkit/stable_diffusion_model.py           |   6 +
 5 files changed, 310 insertions(+), 8 deletions(-)
 create mode 100644 toolkit/models/te_aug_adapter.py

diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py
index 9e713475..1952174f 100644
--- a/extensions_built_in/sd_trainer/SDTrainer.py
+++ b/extensions_built_in/sd_trainer/SDTrainer.py
@@ -930,6 +930,9 @@ class SDTrainer(BaseSDTrainProcess):
         if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
             grad_on_text_encoder = True

+        if self.adapter_config.type == 'te_augmenter':
+            grad_on_text_encoder = True
+
         # have a blank network so we can wrap it in a context and set multipliers without checking every time
         if self.network is not None:
             network = self.network
@@ -1045,6 +1048,8 @@ class SDTrainer(BaseSDTrainProcess):
         unconditional_embeds = None
         if grad_on_text_encoder:
             with torch.set_grad_enabled(True):
+                if isinstance(self.adapter, CustomAdapter):
+                    self.adapter.is_unconditional_run = False
                 conditional_embeds = self.sd.encode_prompt(
                     conditioned_prompts, prompt_2,
                     dropout_prob=self.train_config.prompt_dropout_prob,
@@ -1053,6 +1058,8 @@ class SDTrainer(BaseSDTrainProcess):
                     dtype=dtype)

                 if self.train_config.do_cfg:
+                    if isinstance(self.adapter, CustomAdapter):
+                        self.adapter.is_unconditional_run = True
                     # todo only do one and repeat it
                     unconditional_embeds = self.sd.encode_prompt(
                         self.batch_negative_prompt,
@@ -1061,6 +1068,8 @@ class SDTrainer(BaseSDTrainProcess):
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
                         dtype=dtype)
+                    if isinstance(self.adapter, CustomAdapter):
+                        self.adapter.is_unconditional_run = False
         else:
             with torch.set_grad_enabled(False):
                 # make sure it is in eval mode
@@ -1069,6 +1078,8 @@ class SDTrainer(BaseSDTrainProcess):
                         te.eval()
                 else:
                     self.sd.text_encoder.eval()
+                if isinstance(self.adapter, CustomAdapter):
+                    self.adapter.is_unconditional_run = False
                 conditional_embeds = self.sd.encode_prompt(
                     conditioned_prompts, prompt_2,
                     dropout_prob=self.train_config.prompt_dropout_prob,
@@ -1076,12 +1087,16 @@ class SDTrainer(BaseSDTrainProcess):
                     self.device_torch,
                     dtype=dtype)
                 if self.train_config.do_cfg:
+                    if isinstance(self.adapter, CustomAdapter):
+                        self.adapter.is_unconditional_run = True
                     unconditional_embeds = self.sd.encode_prompt(
                         self.batch_negative_prompt,
                         dropout_prob=self.train_config.prompt_dropout_prob,
                         long_prompts=self.do_long_prompts).to(
                         self.device_torch,
                         dtype=dtype)
+                    if isinstance(self.adapter, CustomAdapter):
+                        self.adapter.is_unconditional_run = False

         # detach the embeddings
         conditional_embeds = conditional_embeds.detach()
diff --git a/jobs/process/TrainVAEProcess.py b/jobs/process/TrainVAEProcess.py
index a7da8160..00b098c6 100644
--- a/jobs/process/TrainVAEProcess.py
+++ b/jobs/process/TrainVAEProcess.py
@@ -477,7 +477,8 @@ class TrainVAEProcess(BaseTrainProcess):
             if self.use_critic:
                 loss_string += f" crD: {critic_d_loss:.2e}"

-            if self.optimizer_type.startswith('dadaptation'):
+            if self.optimizer_type.startswith('dadaptation') or \
+                    self.optimizer_type.lower().startswith('prodigy'):
                 learning_rate = (
                     optimizer.param_groups[0]["d"] *
                     optimizer.param_groups[0]["lr"]
diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 694869d8..18780d87 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -10,6 +10,7 @@ from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
 from toolkit.models.ilora import InstantLoRAModule
 from toolkit.models.te_adapter import TEAdapter
+from toolkit.models.te_aug_adapter import TEAugAdapter
 from toolkit.models.vd_adapter import VisionDirectAdapter
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -60,6 +61,7 @@ class CustomAdapter(torch.nn.Module):
         self.current_scale = 1.0
         self.is_active = True
         self.flag_word = "fla9wor0"
+        self.is_unconditional_run = False

         self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None

@@ -83,6 +85,7 @@ class CustomAdapter(torch.nn.Module):
         self.te: Union[T5EncoderModel, CLIPTextModel] = None
         self.tokenizer: CLIPTokenizer = None
         self.te_adapter: TEAdapter = None
+        self.te_augmenter: TEAugAdapter = None
         self.vd_adapter: VisionDirectAdapter = None
         self.conditional_embeds: Optional[torch.Tensor] = None
         self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -149,6 +152,8 @@ class CustomAdapter(torch.nn.Module):
                 raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")

            self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
+        elif self.adapter_type == 'te_augmenter':
+            self.te_augmenter = TEAugAdapter(self, self.sd_ref())
         elif self.adapter_type == 'vision_direct':
             self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
         else:
@@ -269,9 +274,13 @@ class CustomAdapter(torch.nn.Module):
             preprocessor_input_size = self.vision_encoder.config.image_size * 2

             # update the preprocessor so images come in at the right size
-            self.image_processor.size['shortest_edge'] = preprocessor_input_size
-            self.image_processor.crop_size['height'] = preprocessor_input_size
-            self.image_processor.crop_size['width'] = preprocessor_input_size
+            if 'height' in self.image_processor.size:
+                self.image_processor.size['height'] = preprocessor_input_size
+                self.image_processor.size['width'] = preprocessor_input_size
+            elif hasattr(self.image_processor, 'crop_size'):
+                self.image_processor.size['shortest_edge'] = preprocessor_input_size
+                self.image_processor.crop_size['height'] = preprocessor_input_size
+                self.image_processor.crop_size['width'] = preprocessor_input_size

             if self.config.image_encoder_arch == 'clip+':
                 # self.image_processor.config
@@ -340,6 +349,9 @@ class CustomAdapter(torch.nn.Module):

         if 'te_adapter' in state_dict:
             self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
+        if 'te_augmenter' in state_dict:
+            self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
+
         if 'vd_adapter' in state_dict:
             self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
         if 'dvadapter' in state_dict:
@@ -378,6 +390,11 @@ class CustomAdapter(torch.nn.Module):
         elif self.adapter_type == 'text_encoder':
             state_dict["te_adapter"] = self.te_adapter.state_dict()
             return state_dict
+        elif self.adapter_type == 'te_augmenter':
+            if self.config.train_image_encoder:
+                state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+            state_dict["te_augmenter"] = self.te_augmenter.state_dict()
+            return state_dict
         elif self.adapter_type == 'vision_direct':
             state_dict["dvadapter"] = self.vd_adapter.state_dict()
             if self.config.train_image_encoder:
@@ -647,7 +664,7 @@ class CustomAdapter(torch.nn.Module):
             has_been_preprocessed=False,
             quad_count=4,
     ) -> PromptEmbeds:
-        if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct':
+        if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
             if tensors_0_1 is None:
                 tensors_0_1 = self.get_empty_clip_image(1)
                 has_been_preprocessed = True
@@ -675,7 +692,7 @@ class CustomAdapter(torch.nn.Module):
             clip_image = tensors_0_1
             batch_size = clip_image.shape[0]
-            if self.adapter_type == 'vision_direct':
+            if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
                 # add an unconditional so we can save it
                 unconditional = self.get_empty_clip_image(batch_size).to(
                     clip_image.device, dtype=clip_image.dtype
                 )
@@ -730,7 +747,7 @@ class CustomAdapter(torch.nn.Module):
                     img_embeds = img_embeds.detach()
                 self.ilora_module(img_embeds)

-        if self.adapter_type == 'vision_direct':
+        if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
            with torch.set_grad_enabled(is_training):
                 if is_training and self.config.train_image_encoder:
                     self.vision_encoder.train()
@@ -754,8 +771,14 @@ class CustomAdapter(torch.nn.Module):
                 if not is_training or not self.config.train_image_encoder:
                     clip_image_embeds = clip_image_embeds.detach()

+            if self.adapter_type == 'te_augmenter':
+                clip_image_embeds = self.te_augmenter(clip_image_embeds)
+
             # save them to the conditional and unconditional
-            self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+            try:
+                self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+            except ValueError:
+                raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")

     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
         if self.config.train_only_image_encoder:
@@ -781,6 +804,10 @@ class CustomAdapter(torch.nn.Module):
                 yield from attn_processor.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
+        elif self.config.type == 'te_augmenter':
+            yield from self.te_augmenter.parameters(recurse)
+            if self.config.train_image_encoder:
+                yield from self.vision_encoder.parameters(recurse)
         else:
             raise NotImplementedError

diff --git a/toolkit/models/te_aug_adapter.py b/toolkit/models/te_aug_adapter.py
new file mode 100644
index 00000000..02cbbec1
--- /dev/null
+++ b/toolkit/models/te_aug_adapter.py
@@ -0,0 +1,253 @@
+import sys
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import weakref
+from typing import Union, TYPE_CHECKING, Optional, Tuple
+
+from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
+from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
+
+from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
+from toolkit.paths import REPOS_ROOT
+from toolkit.resampler import Resampler
+
+sys.path.append(REPOS_ROOT)
+
+from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
+
+if TYPE_CHECKING:
+    from toolkit.stable_diffusion_model import StableDiffusion
+    from toolkit.custom_adapter import CustomAdapter
+
+
+class TEAugAdapterCLIPAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter'):
+        super().__init__()
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.attn_module_ref: weakref.ref = weakref.ref(attn_module)
+        self.k_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
+        self.v_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
+        # copy the weights from the original module
+        self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01
+        self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01
+        # reset the bias
+        self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001
+        self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001
+
+        self.zipper = ZipperModule(
+            in_size=attn_module.embed_dim,
+            in_tokens=77 * 2,
+            out_size=attn_module.embed_dim,
+            out_tokens=77,
+            hidden_size=attn_module.embed_dim,
+            hidden_tokens=77,
+        )
+        # self.k_proj_adapter.weight.data = torch.zeros_like(attn_module.k_proj.weight.data)
+        # self.v_proj_adapter.weight.data = torch.zeros_like(attn_module.v_proj.weight.data)
+        # # reset the bias
+        # self.k_proj_adapter.bias.data = torch.zeros_like(attn_module.k_proj.bias.data)
+        # self.v_proj_adapter.bias.data = torch.zeros_like(attn_module.v_proj.bias.data)
+
+        # replace the original forward with our forward
+        self.original_forward = attn_module.forward
+        attn_module.forward = self.forward
+
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+            self,
+            hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            causal_attention_mask: Optional[torch.Tensor] = None,
+            output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        """Input shape: Batch x Time x Channel"""
+
+        attn_module = self.attn_module_ref()
+
+        bsz, tgt_len, embed_dim = hidden_states.size()
+
+        # get query proj
+        query_states = attn_module.q_proj(hidden_states) * attn_module.scale
+        key_states = attn_module._shape(attn_module.k_proj(hidden_states), -1, bsz)
+        value_states = attn_module._shape(attn_module.v_proj(hidden_states), -1, bsz)
+
+        proj_shape = (bsz * attn_module.num_heads, -1, attn_module.head_dim)
+        query_states = attn_module._shape(query_states, tgt_len, bsz).view(*proj_shape)
+        key_states = key_states.view(*proj_shape)
+        value_states = value_states.view(*proj_shape)
+
+        src_len = key_states.size(1)
+        attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+        if attn_weights.size() != (bsz * attn_module.num_heads, tgt_len, src_len):
+            raise ValueError(
+                f"Attention weights should be of size {(bsz * attn_module.num_heads, tgt_len, src_len)}, but is"
+                f" {attn_weights.size()}"
+            )
+
+        # apply the causal_attention_mask first
+        if causal_attention_mask is not None:
+            if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
+                    f" {causal_attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + causal_attention_mask
+            attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)
+
+        if attention_mask is not None:
+            if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+                raise ValueError(
+                    f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+                )
+            attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attention_mask
+            attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)
+
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        if output_attentions:
+            # this operation is a bit awkward, but it's required to
+            # make sure that attn_weights keeps its gradient.
+            # In order to do so, attn_weights have to be reshaped
+            # twice and have to be reused in the following
+            attn_weights_reshaped = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len)
+            attn_weights = attn_weights_reshaped.view(bsz * attn_module.num_heads, tgt_len, src_len)
+        else:
+            attn_weights_reshaped = None
+
+        attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
+
+        attn_output = torch.bmm(attn_probs, value_states)
+
+        if attn_output.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
+            raise ValueError(
+                f"`attn_output` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
+                f" {attn_output.size()}"
+            )
+
+        attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
+        attn_output = attn_output.transpose(1, 2)
+        attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
+
+        adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref()
+        if self.adapter_ref().is_active and adapter.conditional_embeds is not None:
+            # apply the adapter
+
+            if adapter.is_unconditional_run:
+                embeds = adapter.unconditional_embeds
+            else:
+                embeds = adapter.conditional_embeds
+                # if the shape is not the same on batch, we are doing cfg and need to concat unconditional as well
+                if embeds.size(0) != bsz:
+                    embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0)
+
+            key_states_raw = self.k_proj_adapter(embeds)
+            key_states = attn_module._shape(key_states_raw, -1, bsz)
+            value_states_raw = self.v_proj_adapter(embeds)
+            value_states = attn_module._shape(value_states_raw, -1, bsz)
+            key_states = key_states.view(*proj_shape)
+            value_states = value_states.view(*proj_shape)
+            attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+
+            attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+            attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
+            attn_output_adapter = torch.bmm(attn_probs, value_states)
+
+            if attn_output_adapter.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
+                raise ValueError(
+                    f"`attn_output_adapter` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
+                    f" {attn_output_adapter.size()}"
+                )
+
+            attn_output_adapter = attn_output_adapter.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
+            attn_output_adapter = attn_output_adapter.transpose(1, 2)
+            attn_output_adapter = attn_output_adapter.reshape(bsz, tgt_len, embed_dim)
+
+            attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1))
+
+            # attn_output_adapter = attn_module.out_proj(attn_output_adapter)
+            attn_output = attn_output + attn_output_adapter
+
+        attn_output = attn_module.out_proj(attn_output)
+
+        return attn_output, attn_weights_reshaped
+
+
+class TEAugAdapter(torch.nn.Module):
+    def __init__(
+            self,
+            adapter: 'CustomAdapter',
+            sd: 'StableDiffusion',
+    ):
+        super(TEAugAdapter, self).__init__()
+        self.adapter_ref: weakref.ref = weakref.ref(adapter)
+        self.sd_ref: weakref.ref = weakref.ref(sd)
+
+        if isinstance(sd.text_encoder, list):
+            raise ValueError("Dual text encoders is not yet supported")
+
+        # dim will come from text encoder
+        # dim = sd.unet.config['cross_attention_dim']
+        text_encoder: CLIPTextModel = sd.text_encoder
+        dim = text_encoder.config.hidden_size
+
+        clip_encoder: CLIPEncoder = text_encoder.text_model.encoder
+        # dim = clip_encoder.layers[-1].self_attn
+
+        if hasattr(adapter.vision_encoder.config, 'hidden_sizes'):
+            embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
+        else:
+            embedding_dim = adapter.vision_encoder.config.hidden_size
+
+        image_encoder_state_dict = adapter.vision_encoder.state_dict()
+        # max_seq_len = CLIP tokens + CLS token
+        in_tokens = 257
+        if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
+            # clip
+            in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+
+        if adapter.config.image_encoder_arch.startswith('convnext'):
+            in_tokens = 16 * 16
+            embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
+
+        out_tokens = adapter.config.num_tokens if adapter.config.num_tokens > 0 else in_tokens
+        self.image_proj_model = ZipperModule(
+            in_size=embedding_dim,
+            in_tokens=in_tokens,
+            out_size=dim,
+            out_tokens=out_tokens,
+            hidden_size=dim,
+            hidden_tokens=out_tokens,
+        )
+        # init adapter modules
+        attn_procs = {}
+        for idx, layer in enumerate(clip_encoder.layers):
+            name = f"clip_attention.{idx}"
+            attn_procs[name] = TEAugAdapterCLIPAttention(
+                layer.self_attn,
+                self
+            )
+
+        self.adapter_modules = torch.nn.ModuleList(list(attn_procs.values()))
+
+    # make a getter to see if is active
+    @property
+    def is_active(self):
+        return self.adapter_ref().is_active
+
+
+    def forward(self, input):
+        # # apply the adapter
+        input = self.image_proj_model(input)
+        # self.embeds = input
+        return input
diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py
index 5694a146..cb3433ad 100644
--- a/toolkit/stable_diffusion_model.py
+++ b/toolkit/stable_diffusion_model.py
@@ -585,11 +585,17 @@ class StableDiffusion:
                 )

             # encode the prompt ourselves so we can do fun stuff with embeddings
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = False
             conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = True
             unconditional_embeds = self.encode_prompt(
                 gen_config.negative_prompt, gen_config.negative_prompt_2,
                 force_all=True
             )
+            if isinstance(self.adapter, CustomAdapter):
+                self.adapter.is_unconditional_run = False

             # allow any manipulations to take place to embeddings
             gen_config.post_process_embeddings(
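Note (not part of the patch): the core trick in toolkit/models/te_aug_adapter.py is that TEAugAdapterCLIPAttention swaps the forward of each CLIPAttention layer in the text encoder for its own, runs a second attention pass against adapter-supplied image embeds using copied-and-down-scaled k/v projections, merges the two attention outputs through a ZipperModule, and adds the result back before out_proj. The sketch below shows only the forward-wrapping pattern with a generic nn.MultiheadAttention standing in for CLIPAttention; class and attribute names are illustrative, and the weight copying, zipper merge, and conditional/unconditional switching from the patch are omitted.

from typing import Optional

import torch
import torch.nn as nn


class AugmentedAttention(nn.Module):
    def __init__(self, attn: nn.MultiheadAttention, embed_dim: int, num_heads: int = 8):
        super().__init__()
        self.attn = attn  # stands in for the wrapped CLIPAttention layer
        self.aux_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.image_embeds: Optional[torch.Tensor] = None  # set by the adapter before the TE runs

        # swap the wrapped module's forward for ours, as TEAugAdapterCLIPAttention does
        self.original_forward = attn.forward
        attn.forward = self.forward

    def forward(self, query, key, value, **kwargs):
        # run the original attention first, then add an auxiliary cross-attention
        out, weights = self.original_forward(query, key, value, **kwargs)
        if self.image_embeds is not None:
            aux, _ = self.aux_attn(query, self.image_embeds, self.image_embeds)
            out = out + aux
        return out, weights


# usage: wrap a layer, hand it image embeds, then call the wrapped layer as usual
attn = nn.MultiheadAttention(embed_dim=768, num_heads=8, batch_first=True)
aug = AugmentedAttention(attn, embed_dim=768)
aug.image_embeds = torch.randn(2, 77, 768)   # e.g. projected image tokens
hidden = torch.randn(2, 77, 768)             # text hidden states
out, _ = attn(hidden, hidden, hidden)        # routed through AugmentedAttention.forward

Because the wrapper stores the bound original forward and calls it first, the wrapped layer behaves exactly as before whenever image_embeds is None, which mirrors how the patch gates its extra branch on adapter.conditional_embeds being set.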