Work on ipadapters and custom adapters

This commit is contained in:
Jaret Burkett
2024-05-13 06:37:54 -06:00
parent 10e1ecf1e8
commit 5a45c709cd
10 changed files with 150 additions and 67 deletions

View File

@@ -275,16 +275,18 @@ class SDTrainer(BaseSDTrainProcess):
if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
assert not self.train_config.train_turbo
# we need to make the noise prediction be a masked blending of noise and prior_pred
stretched_mask_multiplier = value_map(
mask_multiplier,
batch.file_items[0].dataset_config.mask_min_value,
1.0,
0.0,
1.0
)
with torch.no_grad():
# we need to make the noise prediction be a masked blending of noise and prior_pred
stretched_mask_multiplier = value_map(
mask_multiplier,
batch.file_items[0].dataset_config.mask_min_value,
1.0,
0.0,
1.0
)
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
prior_mask_multiplier = 1.0 - stretched_mask_multiplier
# target_mask_multiplier = mask_multiplier
# mask_multiplier = 1.0

View File

@@ -940,6 +940,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
batch.mask_tensor = double_up_tensor(batch.mask_tensor)
batch.control_tensor = double_up_tensor(batch.control_tensor)
noisy_latent_multiplier = self.train_config.noisy_latent_multiplier
if noisy_latent_multiplier != 1.0:
noisy_latents = noisy_latents * noisy_latent_multiplier
# remove grads for these
noisy_latents.requires_grad = False
noisy_latents = noisy_latents.detach()

View File

@@ -1,7 +1,7 @@
import gc
import os
from collections import OrderedDict
from typing import ForwardRef, List
from typing import ForwardRef, List, Optional, Union
import torch
from safetensors.torch import save_file, load_file
@@ -22,6 +22,7 @@ class GenerateConfig:
self.sampler = kwargs.get('sampler', 'ddpm')
self.width = kwargs.get('width', 512)
self.height = kwargs.get('height', 512)
self.size_list: Union[List[int], None] = kwargs.get('size_list', None)
self.neg = kwargs.get('neg', '')
self.seed = kwargs.get('seed', -1)
self.guidance_scale = kwargs.get('guidance_scale', 7)
@@ -30,6 +31,7 @@ class GenerateConfig:
self.neg_2 = kwargs.get('neg_2', None)
self.prompts = kwargs.get('prompts', None)
self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
self.compile = kwargs.get('compile', False)
self.ext = kwargs.get('ext', 'png')
self.prompt_file = kwargs.get('prompt_file', False)
self.prompts_in_file = self.prompts
@@ -93,17 +95,26 @@ class GenerateProcess(BaseProcess):
self.sd.load_model()
print("Compiling model...")
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
# self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
if self.generate_config.compile:
self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
print(f"Generating {len(self.generate_config.prompts)} images")
# build prompt image configs
prompt_image_configs = []
for prompt in self.generate_config.prompts:
width = self.generate_config.width
height = self.generate_config.height
if self.generate_config.size_list is not None:
# randomly select a size
width, height = random.choice(self.generate_config.size_list)
prompt_image_configs.append(GenerateImageConfig(
prompt=prompt,
prompt_2=self.generate_config.prompt_2,
width=self.generate_config.width,
height=self.generate_config.height,
width=width,
height=height,
num_inference_steps=self.generate_config.sample_steps,
guidance_scale=self.generate_config.guidance_scale,
negative_prompt=self.generate_config.neg,

View File

@@ -21,12 +21,14 @@ from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets,
trigger_dataloader_setup_epoch
from toolkit.config_modules import DatasetConfig
import argparse
from tqdm import tqdm
parser = argparse.ArgumentParser()
parser.add_argument('dataset_folder', type=str, default='input')
parser.add_argument('--epochs', type=int, default=1)
args = parser.parse_args()
dataset_folder = args.dataset_folder
@@ -40,27 +42,27 @@ batch_size = 1
dataset_config = DatasetConfig(
dataset_path=dataset_folder,
resolution=resolution,
caption_ext='json',
# caption_ext='json',
default_caption='default',
clip_image_path='/mnt/Datasets/face_pairs2/control_clean',
# clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
buckets=True,
bucket_tolerance=bucket_tolerance,
poi='person',
augmentations=[
{
'method': 'RandomBrightnessContrast',
'brightness_limit': (-0.3, 0.3),
'contrast_limit': (-0.3, 0.3),
'brightness_by_max': False,
'p': 1.0
},
{
'method': 'HueSaturationValue',
'hue_shift_limit': (-0, 0),
'sat_shift_limit': (-40, 40),
'val_shift_limit': (-40, 40),
'p': 1.0
},
# poi='person',
# augmentations=[
# {
# 'method': 'RandomBrightnessContrast',
# 'brightness_limit': (-0.3, 0.3),
# 'contrast_limit': (-0.3, 0.3),
# 'brightness_by_max': False,
# 'p': 1.0
# },
# {
# 'method': 'HueSaturationValue',
# 'hue_shift_limit': (-0, 0),
# 'sat_shift_limit': (-40, 40),
# 'val_shift_limit': (-40, 40),
# 'p': 1.0
# },
# {
# 'method': 'RGBShift',
# 'r_shift_limit': (-20, 20),
@@ -68,7 +70,7 @@ dataset_config = DatasetConfig(
# 'b_shift_limit': (-20, 20),
# 'p': 1.0
# },
]
# ]
)
@@ -79,7 +81,7 @@ dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_si
# run through an epoch ang check sizes
dataloader_iterator = iter(dataloader)
for epoch in range(args.epochs):
for batch in dataloader:
for batch in tqdm(dataloader):
batch: 'DataLoaderBatchDTO'
img_batch = batch.tensor
@@ -98,7 +100,7 @@ for epoch in range(args.epochs):
show_img(img)
time.sleep(1.0)
# time.sleep(0.1)
# if not last epoch
if epoch < args.epochs - 1:
trigger_dataloader_setup_epoch(dataloader)

View File

@@ -41,20 +41,37 @@ class Embedder(nn.Module):
self.layer_norm = nn.LayerNorm(input_dim)
self.fc1 = nn.Linear(input_dim, mid_dim)
self.gelu = nn.GELU()
self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
# self.fc2 = nn.Linear(mid_dim, mid_dim)
self.fc2 = nn.Linear(mid_dim, mid_dim)
self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
self.fc2.weight.data.zero_()
self.layer_norm2 = nn.LayerNorm(mid_dim)
self.fc3 = nn.Linear(mid_dim, mid_dim)
self.gelu2 = nn.GELU()
self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)
# set the weights to 0
self.fc3.weight.data.zero_()
self.fc4.weight.data.zero_()
# self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
# self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
def forward(self, x):
if len(x.shape) == 2:
x = x.unsqueeze(1)
x = self.layer_norm(x)
x = self.fc1(x)
x = self.gelu(x)
x = self.fc2(x)
x = x.view(-1, self.num_output_tokens, self.output_dim)
x = self.layer_norm2(x)
x = self.fc3(x)
x = self.gelu2(x)
x = self.fc4(x)
# repeat the static tokens for each batch
static_tokens = torch.stack([self.static_tokens] * x.shape[0])
x = static_tokens + x
x = x.view(-1, self.num_output_tokens, self.output_dim)
return x
@@ -89,6 +106,7 @@ class ClipVisionAdapter(torch.nn.Module):
print(f"Adding {placeholder_tokens} tokens to tokenizer")
print(f"Adding {self.config.num_tokens} tokens to tokenizer")
for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
if num_added_tokens != self.config.num_tokens:

View File

@@ -246,6 +246,7 @@ class TrainConfig:
self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
self.img_multiplier = kwargs.get('img_multiplier', 1.0)
self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
self.negative_prompt = kwargs.get('negative_prompt', None)
self.max_negative_prompts = kwargs.get('max_negative_prompts', 1)

View File

@@ -86,18 +86,19 @@ class Embedding:
self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
def restore_embeddings(self):
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
self.tokenizer_list,
self.orig_embeds_params,
self.placeholder_token_ids):
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[
min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
with torch.no_grad():
with torch.no_grad():
# Let's make sure we don't update any embedding weights besides the newly added token
for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
self.tokenizer_list,
self.orig_embeds_params,
self.placeholder_token_ids):
index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
index_no_updates[ min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
text_encoder.get_input_embeddings().weight[
index_no_updates
] = orig_embeds[index_no_updates]
weight = text_encoder.get_input_embeddings().weight
pass
def get_trainable_params(self):
params = []

View File

@@ -387,7 +387,7 @@ class IPAdapter(torch.nn.Module):
cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
image_proj_model = MLPProjModelClipFace(
cross_attention_dim=cross_attn_dim,
id_embeddings_dim=1024,
id_embeddings_dim=self.image_encoder.config.projection_dim,
num_tokens=self.config.num_tokens, # usually 4
)
elif adapter_config.type == 'ip+':
@@ -486,7 +486,21 @@ class IPAdapter(torch.nn.Module):
attn_processor_names = []
blocks = []
transformer_blocks = []
for name in attn_processor_keys:
name_split = name.split(".")
block_name = f"{name_split[0]}.{name_split[1]}"
transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1
if transformer_idx >= 0:
transformer_name = ".".join(name_split[:2])
transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
if transformer_name not in transformer_blocks:
transformer_blocks.append(transformer_name)
if block_name not in blocks:
blocks.append(block_name)
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
sd.unet.config['cross_attention_dim']
if name.startswith("mid_block"):

View File

@@ -15,6 +15,30 @@ if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
class ILoRAProjModule(torch.nn.Module):
def __init__(self, num_modules=1, dim=4, embeddings_dim=512):
super().__init__()
self.num_modules = num_modules
self.num_dim = dim
self.norm = torch.nn.LayerNorm(embeddings_dim)
self.proj = torch.nn.Sequential(
torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
torch.nn.GELU(),
torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
)
# Initialize the last linear layer weights near zero
torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
torch.nn.init.zeros_(self.proj[2].bias)
def forward(self, x):
x = self.norm(x)
x = self.proj(x)
x = x.reshape(-1, self.num_modules, self.num_dim)
return x
class InstantLoRAMidModule(torch.nn.Module):
def __init__(
self,
@@ -54,7 +78,7 @@ class InstantLoRAMidModule(torch.nn.Module):
raise e
# apply tanh to limit values to -1 to 1
# scaler = torch.tanh(scaler)
return x * (scaler + 1.0)
return x * scaler
class InstantLoRAModule(torch.nn.Module):
@@ -92,20 +116,25 @@ class InstantLoRAModule(torch.nn.Module):
# num_blocks=1,
# )
# heads = 20
heads = 12
dim = 1280
output_dim = self.dim
self.resampler = Resampler(
dim=dim,
depth=4,
dim_head=64,
heads=heads,
num_queries=len(lora_modules),
embedding_dim=self.vision_hidden_size,
max_seq_len=self.vision_tokens,
output_dim=output_dim,
ff_mult=4
)
# heads = 12
# dim = 1280
# output_dim = self.dim
self.proj_module = ILoRAProjModule(
num_modules=len(lora_modules),
dim=self.dim,
embeddings_dim=self.vision_hidden_size,
)
# self.resampler = Resampler(
# dim=dim,
# depth=4,
# dim_head=64,
# heads=heads,
# num_queries=len(lora_modules),
# embedding_dim=self.vision_hidden_size,
# max_seq_len=self.vision_tokens,
# output_dim=output_dim,
# ff_mult=4
# )
for idx, lora_module in enumerate(lora_modules):
# add a new mid module that will take the original forward and add a vector to it
@@ -128,6 +157,6 @@ class InstantLoRAModule(torch.nn.Module):
# expand token rank if only rank 2
if len(img_embeds.shape) == 2:
img_embeds = img_embeds.unsqueeze(1)
img_embeds = self.resampler(img_embeds)
img_embeds = self.proj_module(img_embeds)
self.img_embeds = img_embeds

View File

@@ -390,7 +390,7 @@ def sample_images(
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
def apply_noise_offset(noise, noise_offset):
if noise_offset is None or noise_offset < 0.0000001:
if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001):
return noise
noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
return noise