Added te aug adapter

Jaret Burkett
2024-02-21 21:30:26 -07:00
parent 49c41e6a5f
commit b68c3ef734
5 changed files with 310 additions and 8 deletions

View File

@@ -930,6 +930,9 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, ClipVisionAdapter):
grad_on_text_encoder = True
if self.adapter_config.type == 'te_augmenter':
grad_on_text_encoder = True
# have a blank network so we can wrap it in a context and set multipliers without checking every time
if self.network is not None:
network = self.network
@@ -1045,6 +1048,8 @@ class SDTrainer(BaseSDTrainProcess):
unconditional_embeds = None
if grad_on_text_encoder:
with torch.set_grad_enabled(True):
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
@@ -1053,6 +1058,8 @@ class SDTrainer(BaseSDTrainProcess):
dtype=dtype)
if self.train_config.do_cfg:
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
# todo only do one and repeat it
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
@@ -1061,6 +1068,8 @@ class SDTrainer(BaseSDTrainProcess):
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
else:
with torch.set_grad_enabled(False):
# make sure it is in eval mode
@@ -1069,6 +1078,8 @@ class SDTrainer(BaseSDTrainProcess):
te.eval()
else:
self.sd.text_encoder.eval()
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.sd.encode_prompt(
conditioned_prompts, prompt_2,
dropout_prob=self.train_config.prompt_dropout_prob,
@@ -1076,12 +1087,16 @@ class SDTrainer(BaseSDTrainProcess):
self.device_torch,
dtype=dtype)
if self.train_config.do_cfg:
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.sd.encode_prompt(
self.batch_negative_prompt,
dropout_prob=self.train_config.prompt_dropout_prob,
long_prompts=self.do_long_prompts).to(
self.device_torch,
dtype=dtype)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
# detach the embeddings
conditional_embeds = conditional_embeds.detach()
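The flag toggling added in these hunks follows one pattern for both the grad and no-grad paths: mark the run as conditional before encoding the positive prompts, flip to unconditional around the negative-prompt pass when CFG is on, then reset. A minimal sketch of that pattern, using is_unconditional_run and encode_prompt from the diff; the helper name and simplified signature are illustrative only:

def encode_prompts_with_adapter(sd, adapter, prompts, negative_prompts, do_cfg):
    # tell the adapter to inject its conditional image embeds during encoding
    if adapter is not None:
        adapter.is_unconditional_run = False
    conditional_embeds = sd.encode_prompt(prompts)
    unconditional_embeds = None
    if do_cfg:
        # negative prompts get the adapter's unconditional (blank-image) embeds
        if adapter is not None:
            adapter.is_unconditional_run = True
        unconditional_embeds = sd.encode_prompt(negative_prompts)
        if adapter is not None:
            adapter.is_unconditional_run = False  # always reset afterwards
    return conditional_embeds, unconditional_embeds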

View File

@@ -477,7 +477,8 @@ class TrainVAEProcess(BaseTrainProcess):
if self.use_critic:
loss_string += f" crD: {critic_d_loss:.2e}"
- if self.optimizer_type.startswith('dadaptation'):
+ if self.optimizer_type.startswith('dadaptation') or \
+         self.optimizer_type.lower().startswith('prodigy'):
learning_rate = (
optimizer.param_groups[0]["d"] *
optimizer.param_groups[0]["lr"]
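D-Adaptation and Prodigy both keep their adapted multiplier in the param group's "d" entry, so the learning rate shown in the progress string is d * lr rather than the raw lr. A tiny illustration with made-up numbers:

# illustration only; the real values come from optimizer.param_groups[0]
param_group = {"d": 2.5e-04, "lr": 1.0}
learning_rate = param_group["d"] * param_group["lr"]
print(f"lr: {learning_rate:.1e}")  # -> lr: 2.5e-04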

View File

@@ -10,6 +10,7 @@ from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.ilora import InstantLoRAModule
from toolkit.models.te_adapter import TEAdapter
from toolkit.models.te_aug_adapter import TEAugAdapter
from toolkit.models.vd_adapter import VisionDirectAdapter
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
@@ -60,6 +61,7 @@ class CustomAdapter(torch.nn.Module):
self.current_scale = 1.0
self.is_active = True
self.flag_word = "fla9wor0"
self.is_unconditional_run = False
self.vision_encoder: Union[PhotoMakerCLIPEncoder, CLIPVisionModelWithProjection] = None
@@ -83,6 +85,7 @@ class CustomAdapter(torch.nn.Module):
self.te: Union[T5EncoderModel, CLIPTextModel] = None
self.tokenizer: CLIPTokenizer = None
self.te_adapter: TEAdapter = None
self.te_augmenter: TEAugAdapter = None
self.vd_adapter: VisionDirectAdapter = None
self.conditional_embeds: Optional[torch.Tensor] = None
self.unconditional_embeds: Optional[torch.Tensor] = None
@@ -149,6 +152,8 @@ class CustomAdapter(torch.nn.Module):
raise ValueError(f"unknown text encoder arch: {self.config.text_encoder_arch}")
self.te_adapter = TEAdapter(self, self.sd_ref(), self.te, self.tokenizer)
elif self.adapter_type == 'te_augmenter':
self.te_augmenter = TEAugAdapter(self, self.sd_ref())
elif self.adapter_type == 'vision_direct':
self.vd_adapter = VisionDirectAdapter(self, self.sd_ref(), self.vision_encoder)
else:
@@ -269,9 +274,13 @@ class CustomAdapter(torch.nn.Module):
preprocessor_input_size = self.vision_encoder.config.image_size * 2
# update the preprocessor so images come in at the right size
- self.image_processor.size['shortest_edge'] = preprocessor_input_size
- self.image_processor.crop_size['height'] = preprocessor_input_size
- self.image_processor.crop_size['width'] = preprocessor_input_size
+ if 'height' in self.image_processor.size:
+     self.image_processor.size['height'] = preprocessor_input_size
+     self.image_processor.size['width'] = preprocessor_input_size
+ elif hasattr(self.image_processor, 'crop_size'):
+     self.image_processor.size['shortest_edge'] = preprocessor_input_size
+     self.image_processor.crop_size['height'] = preprocessor_input_size
+     self.image_processor.crop_size['width'] = preprocessor_input_size
if self.config.image_encoder_arch == 'clip+':
# self.image_processor.config
@@ -340,6 +349,9 @@ class CustomAdapter(torch.nn.Module):
if 'te_adapter' in state_dict:
self.te_adapter.load_state_dict(state_dict['te_adapter'], strict=strict)
if 'te_augmenter' in state_dict:
self.te_augmenter.load_state_dict(state_dict['te_augmenter'], strict=strict)
if 'vd_adapter' in state_dict:
self.vd_adapter.load_state_dict(state_dict['vd_adapter'], strict=strict)
if 'dvadapter' in state_dict:
@@ -378,6 +390,11 @@ class CustomAdapter(torch.nn.Module):
elif self.adapter_type == 'text_encoder':
state_dict["te_adapter"] = self.te_adapter.state_dict()
return state_dict
elif self.adapter_type == 'te_augmenter':
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
state_dict["te_augmenter"] = self.te_augmenter.state_dict()
return state_dict
elif self.adapter_type == 'vision_direct':
state_dict["dvadapter"] = self.vd_adapter.state_dict()
if self.config.train_image_encoder:
@@ -647,7 +664,7 @@ class CustomAdapter(torch.nn.Module):
has_been_preprocessed=False,
quad_count=4,
) -> PromptEmbeds:
- if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct':
+ if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
if tensors_0_1 is None:
tensors_0_1 = self.get_empty_clip_image(1)
has_been_preprocessed = True
@@ -675,7 +692,7 @@ class CustomAdapter(torch.nn.Module):
clip_image = tensors_0_1
batch_size = clip_image.shape[0]
- if self.adapter_type == 'vision_direct':
+ if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
# add an unconditional so we can save it
unconditional = self.get_empty_clip_image(batch_size).to(
clip_image.device, dtype=clip_image.dtype
@@ -730,7 +747,7 @@ class CustomAdapter(torch.nn.Module):
img_embeds = img_embeds.detach()
self.ilora_module(img_embeds)
- if self.adapter_type == 'vision_direct':
+ if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
@@ -754,8 +771,14 @@ class CustomAdapter(torch.nn.Module):
if not is_training or not self.config.train_image_encoder:
clip_image_embeds = clip_image_embeds.detach()
+ if self.adapter_type == 'te_augmenter':
+     clip_image_embeds = self.te_augmenter(clip_image_embeds)
# save them to the conditional and unconditional
- self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+ try:
+     self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+ except ValueError:
+     raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.train_only_image_encoder:
@@ -781,6 +804,10 @@ class CustomAdapter(torch.nn.Module):
yield from attn_processor.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
elif self.config.type == 'te_augmenter':
yield from self.te_augmenter.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
else:
raise NotImplementedError
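For the te_augmenter (and vision_direct) path, a blank "empty" clip image is stacked in front of the real batch so a single vision-encoder pass yields both the unconditional and conditional embeds; the chunk(2, dim=0) at the end splits them back apart, which is what the new try/except guards. A minimal shape sketch with illustrative dimensions:

import torch

batch_size, num_tokens, dim = 2, 257, 1024  # illustrative shapes
unconditional_images_embeds = torch.randn(batch_size, num_tokens, dim)  # from blank images
conditional_images_embeds = torch.randn(batch_size, num_tokens, dim)    # from the real batch

# encoded together as one batch: unconditional half first, conditional half second
clip_image_embeds = torch.cat([unconditional_images_embeds, conditional_images_embeds], dim=0)

# a batch that cannot be split into two halves (e.g. a lone sample) trips the ValueError guard above
unconditional_embeds, conditional_embeds = clip_image_embeds.chunk(2, dim=0)
assert unconditional_embeds.shape == (batch_size, num_tokens, dim)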

View File

@@ -0,0 +1,253 @@
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import weakref
from typing import Union, TYPE_CHECKING, Optional, Tuple
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer
from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPAttention
from toolkit.models.zipper_resampler import ZipperResampler, ZipperModule
from toolkit.paths import REPOS_ROOT
from toolkit.resampler import Resampler
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.attention_processor import AttnProcessor2_0
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
from toolkit.custom_adapter import CustomAdapter
class TEAugAdapterCLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, attn_module: 'CLIPAttention', adapter: 'TEAugAdapter'):
super().__init__()
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.attn_module_ref: weakref.ref = weakref.ref(attn_module)
self.k_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
self.v_proj_adapter = nn.Linear(attn_module.embed_dim, attn_module.embed_dim)
# initialize from the original module's projection weights, scaled down so the adapter branch starts with a small contribution
self.k_proj_adapter.weight.data = attn_module.k_proj.weight.data.clone() * 0.01
self.v_proj_adapter.weight.data = attn_module.v_proj.weight.data.clone() * 0.01
# do the same for the biases, scaled down further
self.k_proj_adapter.bias.data = attn_module.k_proj.bias.data.clone() * 0.001
self.v_proj_adapter.bias.data = attn_module.v_proj.bias.data.clone() * 0.001
self.zipper = ZipperModule(
in_size=attn_module.embed_dim,
in_tokens=77 * 2,
out_size=attn_module.embed_dim,
out_tokens=77,
hidden_size=attn_module.embed_dim,
hidden_tokens=77,
)
# self.k_proj_adapter.weight.data = torch.zeros_like(attn_module.k_proj.weight.data)
# self.v_proj_adapter.weight.data = torch.zeros_like(attn_module.v_proj.weight.data)
# #reset the bias
# self.k_proj_adapter.bias.data = torch.zeros_like(attn_module.k_proj.bias.data)
# self.v_proj_adapter.bias.data = torch.zeros_like(attn_module.v_proj.bias.data)
# replace the original forward with our forward
self.original_forward = attn_module.forward
attn_module.forward = self.forward
@property
def is_active(self):
return self.adapter_ref().is_active
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
# delegate to the wrapped attention module's head layout; this wrapper has no num_heads/head_dim of its own
attn_module = self.attn_module_ref()
return tensor.view(bsz, seq_len, attn_module.num_heads, attn_module.head_dim).transpose(1, 2).contiguous()
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
attn_module = self.attn_module_ref()
bsz, tgt_len, embed_dim = hidden_states.size()
# get query proj
query_states = attn_module.q_proj(hidden_states) * attn_module.scale
key_states = attn_module._shape(attn_module.k_proj(hidden_states), -1, bsz)
value_states = attn_module._shape(attn_module.v_proj(hidden_states), -1, bsz)
proj_shape = (bsz * attn_module.num_heads, -1, attn_module.head_dim)
query_states = attn_module._shape(query_states, tgt_len, bsz).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
src_len = key_states.size(1)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
if attn_weights.size() != (bsz * attn_module.num_heads, tgt_len, src_len):
raise ValueError(
f"Attention weights should be of size {(bsz * attn_module.num_heads, tgt_len, src_len)}, but is"
f" {attn_weights.size()}"
)
# apply the causal_attention_mask first
if causal_attention_mask is not None:
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is"
f" {causal_attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len) + attention_mask
attn_weights = attn_weights.view(bsz * attn_module.num_heads, tgt_len, src_len)
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
if output_attentions:
# this operation is a bit awkward, but it's required to
# make sure that attn_weights keeps its gradient.
# In order to do so, attn_weights have to be reshaped
# twice and have to be reused in the following
attn_weights_reshaped = attn_weights.view(bsz, attn_module.num_heads, tgt_len, src_len)
attn_weights = attn_weights_reshaped.view(bsz * attn_module.num_heads, tgt_len, src_len)
else:
attn_weights_reshaped = None
attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
attn_output = torch.bmm(attn_probs, value_states)
if attn_output.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
adapter: 'CustomAdapter' = self.adapter_ref().adapter_ref()
if self.adapter_ref().is_active and adapter.conditional_embeds is not None:
# apply the adapter
if adapter.is_unconditional_run:
embeds = adapter.unconditional_embeds
else:
embeds = adapter.conditional_embeds
# if the batch dimension does not match, we are doing cfg and need to concat the unconditional embeds as well
if embeds.size(0) != bsz:
embeds = torch.cat([adapter.unconditional_embeds, embeds], dim=0)
key_states_raw = self.k_proj_adapter(embeds)
key_states = attn_module._shape(key_states_raw, -1, bsz)
value_states_raw = self.v_proj_adapter(embeds)
value_states = attn_module._shape(value_states_raw, -1, bsz)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_probs = nn.functional.dropout(attn_weights, p=attn_module.dropout, training=self.training)
attn_output_adapter = torch.bmm(attn_probs, value_states)
if attn_output_adapter.size() != (bsz * attn_module.num_heads, tgt_len, attn_module.head_dim):
raise ValueError(
f"`attn_output_adapter` should be of size {(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)}, but is"
f" {attn_output_adapter.size()}"
)
attn_output_adapter = attn_output_adapter.view(bsz, attn_module.num_heads, tgt_len, attn_module.head_dim)
attn_output_adapter = attn_output_adapter.transpose(1, 2)
attn_output_adapter = attn_output_adapter.reshape(bsz, tgt_len, embed_dim)
attn_output_adapter = self.zipper(torch.cat([attn_output_adapter, attn_output], dim=1))
# attn_output_adapter = attn_module.out_proj(attn_output_adapter)
attn_output = attn_output + attn_output_adapter
attn_output = attn_module.out_proj(attn_output)
return attn_output, attn_weights_reshaped
class TEAugAdapter(torch.nn.Module):
def __init__(
self,
adapter: 'CustomAdapter',
sd: 'StableDiffusion',
):
super(TEAugAdapter, self).__init__()
self.adapter_ref: weakref.ref = weakref.ref(adapter)
self.sd_ref: weakref.ref = weakref.ref(sd)
if isinstance(sd.text_encoder, list):
raise ValueError("Dual text encoders is not yet supported")
# dim will come from text encoder
# dim = sd.unet.config['cross_attention_dim']
text_encoder: CLIPTextModel = sd.text_encoder
dim = text_encoder.config.hidden_size
clip_encoder: CLIPEncoder = text_encoder.text_model.encoder
# dim = clip_encoder.layers[-1].self_attn
if hasattr(adapter.vision_encoder.config, 'hidden_sizes'):
embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
else:
embedding_dim = adapter.vision_encoder.config.hidden_size
image_encoder_state_dict = adapter.vision_encoder.state_dict()
# max_seq_len = CLIP tokens + CLS token
in_tokens = 257
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
# clip
in_tokens = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
if adapter.config.image_encoder_arch.startswith('convnext'):
in_tokens = 16 * 16
embedding_dim = adapter.vision_encoder.config.hidden_sizes[-1]
out_tokens = adapter.config.num_tokens if adapter.config.num_tokens > 0 else in_tokens
self.image_proj_model = ZipperModule(
in_size=embedding_dim,
in_tokens=in_tokens,
out_size=dim,
out_tokens=out_tokens,
hidden_size=dim,
hidden_tokens=out_tokens,
)
# init adapter modules
attn_procs = {}
for idx, layer in enumerate(clip_encoder.layers):
name = f"clip_attention.{idx}"
attn_procs[name] = TEAugAdapterCLIPAttention(
layer.self_attn,
self
)
self.adapter_modules = torch.nn.ModuleList(list(attn_procs.values()))
# property to check whether the adapter is currently active
@property
def is_active(self):
return self.adapter_ref().is_active
def forward(self, input):
# apply the adapter: project the vision-encoder embeds to the text encoder's hidden size and token count
input = self.image_proj_model(input)
# self.embeds = input
return input
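The central wiring trick in this new file is that TEAugAdapterCLIPAttention never gets registered inside the text encoder; it keeps a weakref to each CLIPAttention block, stores the block's original forward, and installs its own forward in its place, only adding the image-conditioned key/value branch when the adapter is active. A stripped-down sketch of that hijack pattern, assuming a generic wrapped layer and a stand-in residual branch rather than the full adapter attention:

import weakref
import torch
import torch.nn as nn

class ForwardHijack(nn.Module):
    def __init__(self, wrapped: nn.Module, dim: int):
        super().__init__()
        # weakref avoids a reference cycle between the wrapper and the wrapped module
        self.wrapped_ref = weakref.ref(wrapped)
        self.extra_proj = nn.Linear(dim, dim)      # stand-in for the k/v adapter projections
        self.original_forward = wrapped.forward    # keep the original bound method
        wrapped.forward = self.forward             # every call now routes through this wrapper

    def forward(self, hidden_states, *args, **kwargs):
        out = self.original_forward(hidden_states, *args, **kwargs)
        # the real adapter computes an image-conditioned attention output here and
        # merges it into `out`; this sketch just adds a learned residual
        return out + self.extra_proj(hidden_states)

# usage: the wrapped layer now transparently runs the augmented forward
layer = nn.Linear(64, 64)
hijack = ForwardHijack(layer, dim=64)
y = layer(torch.randn(2, 77, 64))  # routed through ForwardHijack.forward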

View File

@@ -585,11 +585,17 @@ class StableDiffusion:
)
# encode the prompt ourselves so we can do fun stuff with embeddings
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
conditional_embeds = self.encode_prompt(gen_config.prompt, gen_config.prompt_2, force_all=True)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = True
unconditional_embeds = self.encode_prompt(
gen_config.negative_prompt, gen_config.negative_prompt_2, force_all=True
)
if isinstance(self.adapter, CustomAdapter):
self.adapter.is_unconditional_run = False
# allow any manipulations to take place to embeddings
gen_config.post_process_embeddings(