Work on ipadapters and custom adapters

Author: Jaret Burkett
Date: 2024-05-13 06:37:54 -06:00
Parent: 10e1ecf1e8
Commit: 5a45c709cd
10 changed files with 150 additions and 67 deletions

View File

@@ -275,16 +275,18 @@ class SDTrainer(BaseSDTrainProcess):
         if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
             assert not self.train_config.train_turbo
-            # we need to make the noise prediction be a masked blending of noise and prior_pred
-            stretched_mask_multiplier = value_map(
-                mask_multiplier,
-                batch.file_items[0].dataset_config.mask_min_value,
-                1.0,
-                0.0,
-                1.0
-            )
-
-            prior_mask_multiplier = 1.0 - stretched_mask_multiplier
+            with torch.no_grad():
+                # we need to make the noise prediction be a masked blending of noise and prior_pred
+                stretched_mask_multiplier = value_map(
+                    mask_multiplier,
+                    batch.file_items[0].dataset_config.mask_min_value,
+                    1.0,
+                    0.0,
+                    1.0
+                )
+
+                prior_mask_multiplier = 1.0 - stretched_mask_multiplier
             # target_mask_multiplier = mask_multiplier
             # mask_multiplier = 1.0
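This hunk does not change the blend itself; it only moves the mask stretching and prior blending under torch.no_grad() so none of it is tracked by autograd. A rough, self-contained sketch of the masked-blending idea, with an assumed value_map signature and toy shapes (not the project's exact code):

    import torch

    def value_map(x, in_min, in_max, out_min, out_max):
        # assumed helper: linearly remap x from [in_min, in_max] to [out_min, out_max]
        return (x - in_min) / (in_max - in_min) * (out_max - out_min) + out_min

    with torch.no_grad():
        mask_multiplier = torch.rand(1, 1, 64, 64)                   # per-pixel mask weight
        stretched = value_map(mask_multiplier, 0.2, 1.0, 0.0, 1.0)   # stretch [mask_min_value, 1] to [0, 1]
        prior_mask_multiplier = 1.0 - stretched                      # weight for the prior prediction
        target = torch.randn(1, 4, 64, 64)
        prior_pred = torch.randn(1, 4, 64, 64)
        # masked blend: keep the target inside the mask, fall back to the prior outside it
        blended = target * stretched + prior_pred * prior_mask_multiplier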

View File

@@ -940,6 +940,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 batch.mask_tensor = double_up_tensor(batch.mask_tensor)
                 batch.control_tensor = double_up_tensor(batch.control_tensor)
+
+            noisy_latent_multiplier = self.train_config.noisy_latent_multiplier
+            if noisy_latent_multiplier != 1.0:
+                noisy_latents = noisy_latents * noisy_latent_multiplier
+
             # remove grads for these
             noisy_latents.requires_grad = False
             noisy_latents = noisy_latents.detach()

View File

@@ -1,7 +1,7 @@
 import gc
 import os
 from collections import OrderedDict
-from typing import ForwardRef, List
+from typing import ForwardRef, List, Optional, Union
 import torch
 from safetensors.torch import save_file, load_file
@@ -22,6 +22,7 @@ class GenerateConfig:
         self.sampler = kwargs.get('sampler', 'ddpm')
         self.width = kwargs.get('width', 512)
         self.height = kwargs.get('height', 512)
+        self.size_list: Union[List[int], None] = kwargs.get('size_list', None)
         self.neg = kwargs.get('neg', '')
         self.seed = kwargs.get('seed', -1)
         self.guidance_scale = kwargs.get('guidance_scale', 7)
@@ -30,6 +31,7 @@ class GenerateConfig:
         self.neg_2 = kwargs.get('neg_2', None)
         self.prompts = kwargs.get('prompts', None)
         self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
+        self.compile = kwargs.get('compile', False)
         self.ext = kwargs.get('ext', 'png')
         self.prompt_file = kwargs.get('prompt_file', False)
         self.prompts_in_file = self.prompts
@@ -93,17 +95,26 @@ class GenerateProcess(BaseProcess):
         self.sd.load_model()
         print("Compiling model...")
-        self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
+        # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
+        if self.generate_config.compile:
+            self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
         print(f"Generating {len(self.generate_config.prompts)} images")
         # build prompt image configs
         prompt_image_configs = []
         for prompt in self.generate_config.prompts:
+            width = self.generate_config.width
+            height = self.generate_config.height
+            if self.generate_config.size_list is not None:
+                # randomly select a size
+                width, height = random.choice(self.generate_config.size_list)
             prompt_image_configs.append(GenerateImageConfig(
                 prompt=prompt,
                 prompt_2=self.generate_config.prompt_2,
-                width=self.generate_config.width,
-                height=self.generate_config.height,
+                width=width,
+                height=height,
                 num_inference_steps=self.generate_config.sample_steps,
                 guidance_scale=self.generate_config.guidance_scale,
                 negative_prompt=self.generate_config.neg,
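The new size_list lets each prompt draw a random (width, height) pair instead of the single configured size, and compile makes torch.compile opt-in rather than always on. A minimal sketch of the selection logic, assuming random is imported in this module and that size_list holds (width, height) pairs:

    import random

    size_list = [(768, 1024), (1024, 1024), (1216, 832)]   # illustrative sizes
    width, height = 1024, 1024                              # defaults from the config

    if size_list is not None:
        # one size per prompt, chosen at random, as in the loop above
        width, height = random.choice(size_list)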

View File

@@ -21,12 +21,14 @@ from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets,
     trigger_dataloader_setup_epoch
 from toolkit.config_modules import DatasetConfig
 import argparse
+from tqdm import tqdm
 parser = argparse.ArgumentParser()
 parser.add_argument('dataset_folder', type=str, default='input')
 parser.add_argument('--epochs', type=int, default=1)
 args = parser.parse_args()
 dataset_folder = args.dataset_folder
@@ -40,27 +42,27 @@ batch_size = 1
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
     resolution=resolution,
-    caption_ext='json',
+    # caption_ext='json',
     default_caption='default',
-    clip_image_path='/mnt/Datasets/face_pairs2/control_clean',
+    # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
-    poi='person',
-    augmentations=[
-        {
-            'method': 'RandomBrightnessContrast',
-            'brightness_limit': (-0.3, 0.3),
-            'contrast_limit': (-0.3, 0.3),
-            'brightness_by_max': False,
-            'p': 1.0
-        },
-        {
-            'method': 'HueSaturationValue',
-            'hue_shift_limit': (-0, 0),
-            'sat_shift_limit': (-40, 40),
-            'val_shift_limit': (-40, 40),
-            'p': 1.0
-        },
+    # poi='person',
+    # augmentations=[
+    #     {
+    #         'method': 'RandomBrightnessContrast',
+    #         'brightness_limit': (-0.3, 0.3),
+    #         'contrast_limit': (-0.3, 0.3),
+    #         'brightness_by_max': False,
+    #         'p': 1.0
+    #     },
+    #     {
+    #         'method': 'HueSaturationValue',
+    #         'hue_shift_limit': (-0, 0),
+    #         'sat_shift_limit': (-40, 40),
+    #         'val_shift_limit': (-40, 40),
+    #         'p': 1.0
+    #     },
         # {
         #     'method': 'RGBShift',
         #     'r_shift_limit': (-20, 20),
@@ -68,7 +70,7 @@ dataset_config = DatasetConfig(
         #     'b_shift_limit': (-20, 20),
         #     'p': 1.0
         # },
-    ]
+    # ]
 )
@@ -79,7 +81,7 @@ dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_si
 # run through an epoch ang check sizes
 dataloader_iterator = iter(dataloader)
 for epoch in range(args.epochs):
-    for batch in dataloader:
+    for batch in tqdm(dataloader):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor
@@ -98,7 +100,7 @@ for epoch in range(args.epochs):
             show_img(img)
-            time.sleep(1.0)
+            # time.sleep(0.1)
     # if not last epoch
     if epoch < args.epochs - 1:
         trigger_dataloader_setup_epoch(dataloader)

View File

@@ -41,20 +41,37 @@ class Embedder(nn.Module):
         self.layer_norm = nn.LayerNorm(input_dim)
         self.fc1 = nn.Linear(input_dim, mid_dim)
         self.gelu = nn.GELU()
-        self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
-
-        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
+        # self.fc2 = nn.Linear(mid_dim, mid_dim)
+        self.fc2 = nn.Linear(mid_dim, mid_dim)
+        self.fc2.weight.data.zero_()
+        self.layer_norm2 = nn.LayerNorm(mid_dim)
+        self.fc3 = nn.Linear(mid_dim, mid_dim)
+        self.gelu2 = nn.GELU()
+        self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)
+        # set the weights to 0
+        self.fc3.weight.data.zero_()
+        self.fc4.weight.data.zero_()
+        # self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
+        # self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim))

     def forward(self, x):
+        if len(x.shape) == 2:
+            x = x.unsqueeze(1)
         x = self.layer_norm(x)
         x = self.fc1(x)
         x = self.gelu(x)
         x = self.fc2(x)
-        x = x.view(-1, self.num_output_tokens, self.output_dim)
-
-        # repeat the static tokens for each batch
-        static_tokens = torch.stack([self.static_tokens] * x.shape[0])
-        x = static_tokens + x
+        x = self.layer_norm2(x)
+        x = self.fc3(x)
+        x = self.gelu2(x)
+        x = self.fc4(x)
+        x = x.view(-1, self.num_output_tokens, self.output_dim)
         return x
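The projector is now a deeper MLP (two extra linear layers plus a second LayerNorm) with the later layers zero-initialized, replacing the learned static_tokens that were previously added to the output. A quick shape check with assumed dimensions (1024-d vision embedding in, 4 tokens of 768 out), zero-initialization omitted, not the project's exact class:

    import torch
    import torch.nn as nn

    input_dim, mid_dim, output_dim, num_tokens = 1024, 1024, 768, 4

    proj = nn.Sequential(
        nn.LayerNorm(input_dim),
        nn.Linear(input_dim, mid_dim), nn.GELU(),
        nn.Linear(mid_dim, mid_dim),
        nn.LayerNorm(mid_dim),
        nn.Linear(mid_dim, mid_dim), nn.GELU(),
        nn.Linear(mid_dim, output_dim * num_tokens),
    )

    x = torch.randn(2, input_dim)       # [batch, input_dim] vision embedding
    x = x.unsqueeze(1)                  # rank-2 input gets a sequence dim, as in forward()
    out = proj(x).view(-1, num_tokens, output_dim)
    print(out.shape)                    # torch.Size([2, 4, 768])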
@@ -89,6 +106,7 @@ class ClipVisionAdapter(torch.nn.Module):
print(f"Adding {placeholder_tokens} tokens to tokenizer") print(f"Adding {placeholder_tokens} tokens to tokenizer")
print(f"Adding {self.config.num_tokens} tokens to tokenizer") print(f"Adding {self.config.num_tokens} tokens to tokenizer")
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list): for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
num_added_tokens = tokenizer.add_tokens(placeholder_tokens) num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.config.num_tokens: if num_added_tokens != self.config.num_tokens:

View File

@@ -246,6 +246,7 @@ class TrainConfig:
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
+        self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
         self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
         self.negative_prompt = kwargs.get('negative_prompt', None)
         self.max_negative_prompts = kwargs.get('max_negative_prompts', 1)
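noisy_latent_multiplier scales the noised latents just before they are detached and passed to the UNet (see the training-process hunk above); the default of 1.0 is a no-op. A minimal sketch of setting it, assuming TrainConfig lives in toolkit.config_modules alongside the other config classes and accepts plain kwargs:

    from toolkit.config_modules import TrainConfig  # assumed import path

    train_config = TrainConfig(
        noisy_latent_multiplier=0.9,  # feed the UNet latents scaled to 90%
    )
    assert train_config.noisy_latent_multiplier == 0.9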

View File

@@ -86,18 +86,19 @@ class Embedding:
         self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]

     def restore_embeddings(self):
-        # Let's make sure we don't update any embedding weights besides the newly added token
-        for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
-                                                                               self.tokenizer_list,
-                                                                               self.orig_embeds_params,
-                                                                               self.placeholder_token_ids):
-            index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
-            index_no_updates[
-                min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
-            with torch.no_grad():
-                text_encoder.get_input_embeddings().weight[
-                    index_no_updates
-                ] = orig_embeds[index_no_updates]
+        with torch.no_grad():
+            # Let's make sure we don't update any embedding weights besides the newly added token
+            for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
+                                                                                   self.tokenizer_list,
+                                                                                   self.orig_embeds_params,
+                                                                                   self.placeholder_token_ids):
+                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+                index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
+                text_encoder.get_input_embeddings().weight[
+                    index_no_updates
+                ] = orig_embeds[index_no_updates]
+                weight = text_encoder.get_input_embeddings().weight
+                pass

     def get_trainable_params(self):
         params = []
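Wrapping the whole restore loop in torch.no_grad() keeps the in-place assignment to the embedding weights out of the autograd graph; the trailing weight/pass lines look like debugging leftovers. A minimal sketch of the masked-restore pattern on a toy embedding, with assumed placeholder token ids:

    import torch
    import torch.nn as nn

    emb = nn.Embedding(10, 4)                 # stand-in for text_encoder.get_input_embeddings()
    orig_embeds = emb.weight.data.clone()     # snapshot taken when the tokens were added
    placeholder_token_ids = [8, 9]            # assumed ids of the newly added tokens

    with torch.no_grad():
        index_no_updates = torch.ones(10, dtype=torch.bool)
        index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
        # restore every embedding except the newly added (trained) tokens
        emb.weight[index_no_updates] = orig_embeds[index_no_updates]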

View File

@@ -387,7 +387,7 @@ class IPAdapter(torch.nn.Module):
             cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
             image_proj_model = MLPProjModelClipFace(
                 cross_attention_dim=cross_attn_dim,
-                id_embeddings_dim=1024,
+                id_embeddings_dim=self.image_encoder.config.projection_dim,
                 num_tokens=self.config.num_tokens,  # usually 4
             )
         elif adapter_config.type == 'ip+':
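Reading id_embeddings_dim from the loaded image encoder instead of hard-coding 1024 keeps the face-ID projector in sync with whatever CLIP vision model is configured. For reference, with the transformers CLIP vision classes the value comes from the model config (checkpoint name and exact class here are illustrative):

    from transformers import CLIPVisionModelWithProjection

    # projection_dim differs per checkpoint, e.g. 768 for ViT-L/14 and 1024 for ViT-H/14
    encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    print(encoder.config.projection_dim)  # 768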
@@ -486,7 +486,21 @@ class IPAdapter(torch.nn.Module):
         attn_processor_names = []
+        blocks = []
+        transformer_blocks = []
         for name in attn_processor_keys:
+            name_split = name.split(".")
+            block_name = f"{name_split[0]}.{name_split[1]}"
+            transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1
+            if transformer_idx >= 0:
+                transformer_name = ".".join(name_split[:2])
+                transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
+                if transformer_name not in transformer_blocks:
+                    transformer_blocks.append(transformer_name)
+            if block_name not in blocks:
+                blocks.append(block_name)
             cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
                 sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
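The new bookkeeping records, for every attention processor key, which top-level UNet block and which transformer block it belongs to. A quick worked example on an illustrative SDXL-style key:

    name = "down_blocks.1.attentions.0.transformer_blocks.2.attn2.processor"
    name_split = name.split(".")

    block_name = f"{name_split[0]}.{name_split[1]}"        # "down_blocks.1"

    transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1
    if transformer_idx >= 0:
        transformer_name = ".".join(name_split[:2])
        transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
        # -> "down_blocks.1.transformer_blocks.2"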

View File

@@ -15,6 +15,30 @@ if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion


+class ILoRAProjModule(torch.nn.Module):
+    def __init__(self, num_modules=1, dim=4, embeddings_dim=512):
+        super().__init__()
+        self.num_modules = num_modules
+        self.num_dim = dim
+        self.norm = torch.nn.LayerNorm(embeddings_dim)
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
+            torch.nn.GELU(),
+            torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
+        )
+        # Initialize the last linear layer weights near zero
+        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
+        torch.nn.init.zeros_(self.proj[2].bias)
+
+    def forward(self, x):
+        x = self.norm(x)
+        x = self.proj(x)
+        x = x.reshape(-1, self.num_modules, self.num_dim)
+        return x
+
+
 class InstantLoRAMidModule(torch.nn.Module):
     def __init__(
             self,
@@ -54,7 +78,7 @@ class InstantLoRAMidModule(torch.nn.Module):
             raise e
         # apply tanh to limit values to -1 to 1
         # scaler = torch.tanh(scaler)
-        return x * (scaler + 1.0)
+        return x * scaler


 class InstantLoRAModule(torch.nn.Module):
@@ -92,20 +116,25 @@ class InstantLoRAModule(torch.nn.Module):
         #     num_blocks=1,
         # )
         # heads = 20
-        heads = 12
-        dim = 1280
-        output_dim = self.dim
-        self.resampler = Resampler(
-            dim=dim,
-            depth=4,
-            dim_head=64,
-            heads=heads,
-            num_queries=len(lora_modules),
-            embedding_dim=self.vision_hidden_size,
-            max_seq_len=self.vision_tokens,
-            output_dim=output_dim,
-            ff_mult=4
-        )
+        # heads = 12
+        # dim = 1280
+        # output_dim = self.dim
+        self.proj_module = ILoRAProjModule(
+            num_modules=len(lora_modules),
+            dim=self.dim,
+            embeddings_dim=self.vision_hidden_size,
+        )
+        # self.resampler = Resampler(
+        #     dim=dim,
+        #     depth=4,
+        #     dim_head=64,
+        #     heads=heads,
+        #     num_queries=len(lora_modules),
+        #     embedding_dim=self.vision_hidden_size,
+        #     max_seq_len=self.vision_tokens,
+        #     output_dim=output_dim,
+        #     ff_mult=4
+        # )

         for idx, lora_module in enumerate(lora_modules):
             # add a new mid module that will take the original forward and add a vector to it
@@ -128,6 +157,6 @@ class InstantLoRAModule(torch.nn.Module):
         # expand token rank if only rank 2
         if len(img_embeds.shape) == 2:
             img_embeds = img_embeds.unsqueeze(1)
-        img_embeds = self.resampler(img_embeds)
+        img_embeds = self.proj_module(img_embeds)
         self.img_embeds = img_embeds
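With the Resampler swapped out, ILoRAProjModule maps one image embedding to one dim-sized vector per wrapped LoRA module, and because its last linear layer is initialized near zero, the multiplicative "return x * scaler" above means each mid module contributes almost nothing at the start of training. A small shape sketch reusing the class added in this diff, with assumed sizes:

    import torch

    # assumed sizes: 1024-d vision embedding, 16 wrapped LoRA modules, per-module dim of 4
    proj = ILoRAProjModule(num_modules=16, dim=4, embeddings_dim=1024)

    img_embeds = torch.randn(2, 1024)   # [batch, embeddings_dim]
    scalers = proj(img_embeds)
    print(scalers.shape)                # torch.Size([2, 16, 4]) -> one scaler vector per module
    print(scalers.abs().mean())         # small at init, so x * scaler starts near zero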

View File

@@ -390,7 +390,7 @@ def sample_images(
 # https://www.crosslabs.org//blog/diffusion-with-offset-noise
 def apply_noise_offset(noise, noise_offset):
-    if noise_offset is None or noise_offset < 0.0000001:
+    if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001):
         return noise
     noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
     return noise
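The old guard also returned early for any negative offset, since every negative number is below 1e-7, so negative noise offsets were silently dropped; the new check only skips offsets within roughly 1e-6 of zero. A tiny illustration in plain Python:

    noise_offset = -0.05

    old_skip = noise_offset < 0.0000001                     # True  -> offset was ignored
    new_skip = -0.000001 < noise_offset < 0.000001          # False -> offset is now applied
    print(old_skip, new_skip)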