Work on ipadapters and custom adapters

This commit is contained in:
Jaret Burkett
2024-05-13 06:37:54 -06:00
parent 10e1ecf1e8
commit 5a45c709cd
10 changed files with 150 additions and 67 deletions

View File

@@ -41,20 +41,37 @@ class Embedder(nn.Module):
self.layer_norm = nn.LayerNorm(input_dim)
self.fc1 = nn.Linear(input_dim, mid_dim)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
# self.fc2 = nn.Linear(mid_dim, mid_dim)
self.fc2 = nn.Linear(mid_dim, mid_dim)
self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
self.fc2.weight.data.zero_()
self.layer_norm2 = nn.LayerNorm(mid_dim)
self.fc3 = nn.Linear(mid_dim, mid_dim)
self.gelu2 = nn.GELU()
self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)
# set the weights to 0
self.fc3.weight.data.zero_()
self.fc4.weight.data.zero_()
# self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
# self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
def forward(self, x):
if len(x.shape) == 2:
x = x.unsqueeze(1)
x = self.layer_norm(x)
x = self.fc1(x)
x = self.gelu(x)
x = self.fc2(x)
x = x.view(-1, self.num_output_tokens, self.output_dim)
x = self.layer_norm2(x)
x = self.fc3(x)
x = self.gelu2(x)
x = self.fc4(x)
# repeat the static tokens for each batch
static_tokens = torch.stack([self.static_tokens] * x.shape[0])
x = static_tokens + x
x = x.view(-1, self.num_output_tokens, self.output_dim)
return x
@@ -89,6 +106,7 @@ class ClipVisionAdapter(torch.nn.Module):
print(f"Adding {placeholder_tokens} tokens to tokenizer")
print(f"Adding {self.config.num_tokens} tokens to tokenizer")
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.config.num_tokens:

View File

@@ -246,6 +246,7 @@ class TrainConfig:
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
self.negative_prompt = kwargs.get('negative_prompt', None)
self.max_negative_prompts = kwargs.get('max_negative_prompts', 1)

View File

@@ -86,18 +86,19 @@ class Embedding:
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
def restore_embeddings(self):
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
self.tokenizer_list,
self.orig_embeds_params,
self.placeholder_token_ids):
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
with torch.no_grad():
with torch.no_grad():
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
self.tokenizer_list,
self.orig_embeds_params,
self.placeholder_token_ids):
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[ min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
text_encoder.get_input_embeddings().weight[
index_no_updates
] = orig_embeds[index_no_updates]
weight = text_encoder.get_input_embeddings().weight
pass
def get_trainable_params(self):
params = []

View File

@@ -387,7 +387,7 @@ class IPAdapter(torch.nn.Module):
cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
image_proj_model = MLPProjModelClipFace(
cross_attention_dim=cross_attn_dim,
id_embeddings_dim=1024,
id_embeddings_dim=self.image_encoder.config.projection_dim,
num_tokens=self.config.num_tokens, # usually 4
)
elif adapter_config.type == 'ip+':
@@ -486,7 +486,21 @@ class IPAdapter(torch.nn.Module):
attn_processor_names = []
blocks = []
transformer_blocks = []
for name in attn_processor_keys:
name_split = name.split(".")
block_name = f"{name_split[0]}.{name_split[1]}"
transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1
if transformer_idx >= 0:
transformer_name = ".".join(name_split[:2])
transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
if transformer_name not in transformer_blocks:
transformer_blocks.append(transformer_name)
if block_name not in blocks:
blocks.append(block_name)
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
sd.unet.config['cross_attention_dim']
if name.startswith("mid_block"):

View File

@@ -15,6 +15,30 @@ if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
class ILoRAProjModule(torch.nn.Module):
    """Project vision embeddings into per-LoRA-module weight vectors.

    Maps an embedding of size ``embeddings_dim`` through a small MLP to
    ``num_modules * dim`` values, reshaped to one ``dim``-vector per module.
    The final linear layer starts with near-zero weights and a zero bias so
    the projection initially contributes almost nothing.
    """

    def __init__(self, num_modules=1, dim=4, embeddings_dim=512):
        super().__init__()
        self.num_modules = num_modules
        self.num_dim = dim
        # Normalize the incoming embeddings before projecting.
        self.norm = torch.nn.LayerNorm(embeddings_dim)
        self.proj = torch.nn.Sequential(
            torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
        )
        # Start the output layer near zero so training begins from a
        # near-identity contribution.
        final_layer = self.proj[2]
        torch.nn.init.uniform_(final_layer.weight, a=-0.01, b=0.01)
        torch.nn.init.zeros_(final_layer.bias)

    def forward(self, x):
        normed = self.norm(x)
        projected = self.proj(normed)
        # One dim-sized vector per module: (batch, num_modules, dim).
        return projected.reshape(-1, self.num_modules, self.num_dim)
class InstantLoRAMidModule(torch.nn.Module):
def __init__(
self,
@@ -54,7 +78,7 @@ class InstantLoRAMidModule(torch.nn.Module):
raise e
# apply tanh to limit values to -1 to 1
# scaler = torch.tanh(scaler)
return x * (scaler + 1.0)
return x * scaler
class InstantLoRAModule(torch.nn.Module):
@@ -92,20 +116,25 @@ class InstantLoRAModule(torch.nn.Module):
# num_blocks=1,
# )
# heads = 20
heads = 12
dim = 1280
output_dim = self.dim
self.resampler = Resampler(
dim=dim,
depth=4,
dim_head=64,
heads=heads,
num_queries=len(lora_modules),
embedding_dim=self.vision_hidden_size,
max_seq_len=self.vision_tokens,
output_dim=output_dim,
ff_mult=4
)
# heads = 12
# dim = 1280
# output_dim = self.dim
self.proj_module = ILoRAProjModule(
num_modules=len(lora_modules),
dim=self.dim,
embeddings_dim=self.vision_hidden_size,
)
# self.resampler = Resampler(
# dim=dim,
# depth=4,
# dim_head=64,
# heads=heads,
# num_queries=len(lora_modules),
# embedding_dim=self.vision_hidden_size,
# max_seq_len=self.vision_tokens,
# output_dim=output_dim,
# ff_mult=4
# )
for idx, lora_module in enumerate(lora_modules):
# add a new mid module that will take the original forward and add a vector to it
@@ -128,6 +157,6 @@ class InstantLoRAModule(torch.nn.Module):
# expand token rank if only rank 2
if len(img_embeds.shape) == 2:
img_embeds = img_embeds.unsqueeze(1)
img_embeds = self.resampler(img_embeds)
img_embeds = self.proj_module(img_embeds)
self.img_embeds = img_embeds

View File

@@ -390,7 +390,7 @@ def sample_images(
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(noise, noise_offset):
    """Add per-(batch, channel) offset noise to a noise tensor.

    Args:
        noise: tensor whose first two dimensions are indexed as
            ``noise.shape[0]`` (batch) and ``noise.shape[1]`` (channels);
            presumably (B, C, H, W) latents — TODO confirm with callers.
        noise_offset: scalar strength of the offset noise, or None to
            disable. Values within 1e-6 of zero (either sign) are treated
            as disabled, so small negative offsets no longer slip through.

    Returns:
        The input tensor unchanged when disabled, otherwise the input plus
        a random scalar per (batch, channel) broadcast over the trailing
        spatial dimensions.
    """
    # Symmetric near-zero tolerance; the diff's interleaved old/new guard
    # lines are collapsed into one `abs()` check with identical semantics.
    if noise_offset is None or abs(noise_offset) < 0.000001:
        return noise
    noise = noise + noise_offset * torch.randn(
        (noise.shape[0], noise.shape[1], 1, 1), device=noise.device
    )
    return noise