Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-28 18:21:16 +00:00
Work on ipadapters and custom adapters
@@ -41,20 +41,37 @@ class Embedder(nn.Module):
        self.layer_norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.gelu = nn.GELU()
        self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
        # self.fc2 = nn.Linear(mid_dim, mid_dim)
        self.fc2 = nn.Linear(mid_dim, mid_dim)

        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
        self.fc2.weight.data.zero_()

        self.layer_norm2 = nn.LayerNorm(mid_dim)
        self.fc3 = nn.Linear(mid_dim, mid_dim)
        self.gelu2 = nn.GELU()
        self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)

        # set the weights to 0
        self.fc3.weight.data.zero_()
        self.fc4.weight.data.zero_()

        # self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
        # self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim))

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        x = self.layer_norm(x)
        x = self.fc1(x)
        x = self.gelu(x)
        x = self.fc2(x)
        x = x.view(-1, self.num_output_tokens, self.output_dim)
        x = self.layer_norm2(x)
        x = self.fc3(x)
        x = self.gelu2(x)
        x = self.fc4(x)

        # repeat the static tokens for each batch
        static_tokens = torch.stack([self.static_tokens] * x.shape[0])
        x = static_tokens + x
        x = x.view(-1, self.num_output_tokens, self.output_dim)

        return x
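For orientation (not part of this commit), here is a minimal standalone sketch of the pattern the hunk above builds: a two-stage MLP whose trailing linear weights are zero-initialized, added on top of learned static tokens so the module starts out emitting roughly the static tokens. All names and dimensions below are placeholders, not values taken from the repo.

import torch
import torch.nn as nn

class TinyEmbedder(nn.Module):
    # Minimal re-creation of the pattern above; not the repo's Embedder class.
    def __init__(self, input_dim=768, mid_dim=1024, output_dim=768, num_output_tokens=4):
        super().__init__()
        self.num_output_tokens = num_output_tokens
        self.output_dim = output_dim
        self.layer_norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, mid_dim)
        self.fc2 = nn.Linear(mid_dim, mid_dim)
        self.layer_norm2 = nn.LayerNorm(mid_dim)
        self.fc3 = nn.Linear(mid_dim, mid_dim)
        self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)
        self.gelu = nn.GELU()
        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
        # zero the projection weights so the learned branch starts near zero
        self.fc2.weight.data.zero_()
        self.fc3.weight.data.zero_()
        self.fc4.weight.data.zero_()

    def forward(self, x):
        if len(x.shape) == 2:
            x = x.unsqueeze(1)                       # (B, D) -> (B, 1, D)
        x = self.fc2(self.gelu(self.fc1(self.layer_norm(x))))
        x = self.fc4(self.gelu(self.fc3(self.layer_norm2(x))))
        x = x.view(-1, self.num_output_tokens, self.output_dim)
        return self.static_tokens.unsqueeze(0) + x   # broadcast static tokens over the batch

emb = TinyEmbedder()
print(emb(torch.randn(2, 768)).shape)                # torch.Size([2, 4, 768])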
@@ -89,6 +106,7 @@ class ClipVisionAdapter(torch.nn.Module):
        print(f"Adding {placeholder_tokens} tokens to tokenizer")
        print(f"Adding {self.config.num_tokens} tokens to tokenizer")

        for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
            num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
            if num_added_tokens != self.config.num_tokens:
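For context (again outside the diff), a rough sketch of the token-registration step these prints sit in, written against the Hugging Face transformers API; the checkpoint name and placeholder strings are illustrative assumptions, not the repo's values.

# Hypothetical illustration; assumes transformers is installed and the checkpoint can be fetched.
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

placeholder_tokens = ["<adapter_0>", "<adapter_1>"]            # stand-in token strings
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)    # returns how many were actually new
if num_added_tokens != len(placeholder_tokens):
    raise ValueError("some placeholder tokens already existed in the tokenizer")

text_encoder.resize_token_embeddings(len(tokenizer))           # grow the embedding matrix for the new rows
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)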
@@ -246,6 +246,7 @@ class TrainConfig:
        self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
        self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
        self.img_multiplier = kwargs.get('img_multiplier', 1.0)
        self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
        self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
        self.negative_prompt = kwargs.get('negative_prompt', None)
        self.max_negative_prompts = kwargs.get('max_negative_prompts', 1)
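Since the fields above are all pulled from kwargs, a hedged usage sketch follows, assuming TrainConfig can be constructed directly from keyword arguments as the getters suggest; the values are arbitrary, not recommendations.

# Hypothetical construction of the config shown in this hunk.
cfg = TrainConfig(
    noise_multiplier=1.0,
    img_multiplier=1.0,
    noisy_latent_multiplier=1.0,
    latent_multiplier=1.0,
    negative_prompt=None,
    max_negative_prompts=1,
)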
@@ -86,18 +86,19 @@ class Embedding:
        self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]

    def restore_embeddings(self):
        # Let's make sure we don't update any embedding weights besides the newly added token
        for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
                                                                               self.tokenizer_list,
                                                                               self.orig_embeds_params,
                                                                               self.placeholder_token_ids):
            index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
            index_no_updates[
                min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
            with torch.no_grad():
        with torch.no_grad():
            # Let's make sure we don't update any embedding weights besides the newly added token
            for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
                                                                                   self.tokenizer_list,
                                                                                   self.orig_embeds_params,
                                                                                   self.placeholder_token_ids):
                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
                index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
                text_encoder.get_input_embeddings().weight[
                    index_no_updates
                ] = orig_embeds[index_no_updates]
                weight = text_encoder.get_input_embeddings().weight
                pass

    def get_trainable_params(self):
        params = []
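A standalone sketch of the mask-and-restore pattern in this hunk, with dummy sizes instead of the repo's tokenizer and text encoder: every embedding row outside the placeholder range is copied back from the saved snapshot, so only the newly added tokens keep their trained values.

import torch

vocab_size, hidden = 49410, 768
orig_embeds = torch.randn(vocab_size, hidden)            # snapshot taken before training
embedding = torch.nn.Embedding(vocab_size, hidden)       # stands in for text_encoder.get_input_embeddings()
placeholder_token_ids = [49408, 49409]                    # hypothetical newly added token ids

index_no_updates = torch.ones((vocab_size,), dtype=torch.bool)
index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False

with torch.no_grad():
    embedding.weight[index_no_updates] = orig_embeds[index_no_updates]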
@@ -387,7 +387,7 @@ class IPAdapter(torch.nn.Module):
            cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
            image_proj_model = MLPProjModelClipFace(
                cross_attention_dim=cross_attn_dim,
                id_embeddings_dim=1024,
                id_embeddings_dim=self.image_encoder.config.projection_dim,
                num_tokens=self.config.num_tokens,  # usually 4
            )
        elif adapter_config.type == 'ip+':
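The change above swaps a hard-coded 1024 for the vision encoder's own projection size. A small hedged check of where that value comes from (the checkpoint name is a placeholder, not necessarily the one the repo loads):

# Hypothetical; assumes the transformers CLIP vision classes are available.
from transformers import CLIPVisionModelWithProjection

image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
print(image_encoder.config.projection_dim)    # 768 for this checkpoint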
@@ -486,7 +486,21 @@ class IPAdapter(torch.nn.Module):
        attn_processor_names = []

        blocks = []
        transformer_blocks = []
        for name in attn_processor_keys:
            name_split = name.split(".")
            block_name = f"{name_split[0]}.{name_split[1]}"
            transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1
            if transformer_idx >= 0:
                transformer_name = ".".join(name_split[:2])
                transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
                if transformer_name not in transformer_blocks:
                    transformer_blocks.append(transformer_name)

            if block_name not in blocks:
                blocks.append(block_name)
            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
                sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
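To make the string handling above concrete, a tiny standalone walk-through using one typical diffusers UNet attention-processor key (the key string below is illustrative, not read from a live model):

name = "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"
name_split = name.split(".")
block_name = f"{name_split[0]}.{name_split[1]}"          # "down_blocks.1"
transformer_idx = name_split.index("transformer_blocks")
transformer_name = ".".join(name_split[:2]) + "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
print(block_name, transformer_name)                       # down_blocks.1 down_blocks.1.transformer_blocks.0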
@@ -15,6 +15,30 @@ if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion


class ILoRAProjModule(torch.nn.Module):
    def __init__(self, num_modules=1, dim=4, embeddings_dim=512):
        super().__init__()

        self.num_modules = num_modules
        self.num_dim = dim
        self.norm = torch.nn.LayerNorm(embeddings_dim)

        self.proj = torch.nn.Sequential(
            torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
            torch.nn.GELU(),
            torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
        )
        # Initialize the last linear layer weights near zero
        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
        torch.nn.init.zeros_(self.proj[2].bias)

    def forward(self, x):
        x = self.norm(x)
        x = self.proj(x)
        x = x.reshape(-1, self.num_modules, self.num_dim)
        return x


class InstantLoRAMidModule(torch.nn.Module):
    def __init__(
            self,
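A quick shape check for the new ILoRAProjModule, assuming the class above is in scope; the sizes are arbitrary.

import torch

proj = ILoRAProjModule(num_modules=16, dim=4, embeddings_dim=512)
pooled = torch.randn(2, 512)          # e.g. one pooled vision embedding per batch item
out = proj(pooled)
print(out.shape)                      # torch.Size([2, 16, 4]): one dim-4 vector per wrapped module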
@@ -54,7 +78,7 @@ class InstantLoRAMidModule(torch.nn.Module):
            raise e
        # apply tanh to limit values to -1 to 1
        # scaler = torch.tanh(scaler)
        return x * (scaler + 1.0)
        return x * scaler


class InstantLoRAModule(torch.nn.Module):
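A two-line check of how the two return variants in this hunk behave when the projection starts near zero: the "scaler + 1.0" form leaves the wrapped output untouched at init, while the bare "scaler" form suppresses it.

import torch

x, scaler = torch.ones(3), torch.zeros(3)
print(x * (scaler + 1.0))   # tensor([1., 1., 1.]) -> identity at init
print(x * scaler)           # tensor([0., 0., 0.]) -> zeroed at init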
@@ -92,20 +116,25 @@ class InstantLoRAModule(torch.nn.Module):
        #     num_blocks=1,
        # )
        # heads = 20
        heads = 12
        dim = 1280
        output_dim = self.dim
        self.resampler = Resampler(
            dim=dim,
            depth=4,
            dim_head=64,
            heads=heads,
            num_queries=len(lora_modules),
            embedding_dim=self.vision_hidden_size,
            max_seq_len=self.vision_tokens,
            output_dim=output_dim,
            ff_mult=4
        )
        # heads = 12
        # dim = 1280
        # output_dim = self.dim
        self.proj_module = ILoRAProjModule(
            num_modules=len(lora_modules),
            dim=self.dim,
            embeddings_dim=self.vision_hidden_size,
        )
        # self.resampler = Resampler(
        #     dim=dim,
        #     depth=4,
        #     dim_head=64,
        #     heads=heads,
        #     num_queries=len(lora_modules),
        #     embedding_dim=self.vision_hidden_size,
        #     max_seq_len=self.vision_tokens,
        #     output_dim=output_dim,
        #     ff_mult=4
        # )

        for idx, lora_module in enumerate(lora_modules):
            # add a new mid module that will take the original forward and add a vector to it
@@ -128,6 +157,6 @@ class InstantLoRAModule(torch.nn.Module):
        # expand token rank if only rank 2
        if len(img_embeds.shape) == 2:
            img_embeds = img_embeds.unsqueeze(1)
        img_embeds = self.resampler(img_embeds)
        img_embeds = self.proj_module(img_embeds)
        self.img_embeds = img_embeds
@@ -390,7 +390,7 @@ def sample_images(
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(noise, noise_offset):
    if noise_offset is None or noise_offset < 0.0000001:
    if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001):
        return noise
    noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
    return noise
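A quick usage check of the updated guard, assuming the apply_noise_offset above is in scope; the latent shape is just a typical SD-sized example.

import torch

latent_noise = torch.randn(2, 4, 64, 64)                          # batch of latent-shaped noise
print(apply_noise_offset(latent_noise, 0.0) is latent_noise)      # True: near-zero offsets are a no-op
print(apply_noise_offset(latent_noise, 0.1).shape)                # torch.Size([2, 4, 64, 64])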