Work on ipadapters and custom adapters
@@ -275,16 +275,18 @@ class SDTrainer(BaseSDTrainProcess):
 
         if self.train_config.inverted_mask_prior and prior_pred is not None and has_mask:
             assert not self.train_config.train_turbo
-            # we need to make the noise prediction be a masked blending of noise and prior_pred
-            stretched_mask_multiplier = value_map(
-                mask_multiplier,
-                batch.file_items[0].dataset_config.mask_min_value,
-                1.0,
-                0.0,
-                1.0
-            )
+            with torch.no_grad():
+                # we need to make the noise prediction be a masked blending of noise and prior_pred
+                stretched_mask_multiplier = value_map(
+                    mask_multiplier,
+                    batch.file_items[0].dataset_config.mask_min_value,
+                    1.0,
+                    0.0,
+                    1.0
+                )
 
-            prior_mask_multiplier = 1.0 - stretched_mask_multiplier
+                prior_mask_multiplier = 1.0 - stretched_mask_multiplier
 
             # target_mask_multiplier = mask_multiplier
             # mask_multiplier = 1.0
 
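For reference, a minimal sketch of the inverted-mask prior blending set up above, assuming value_map is a plain linear range remap (the helper below is a stand-in for illustration, not the toolkit's implementation):

import torch

def value_map(x, in_min, in_max, out_min, out_max):
    # linearly remap x from [in_min, in_max] to [out_min, out_max]
    return (x - in_min) * (out_max - out_min) / (in_max - in_min) + out_min

mask_multiplier = torch.rand(1, 1, 64, 64)  # hypothetical per-pixel loss mask
mask_min_value = 0.1                        # stands in for dataset_config.mask_min_value

with torch.no_grad():
    # stretch [mask_min_value, 1.0] back out to [0.0, 1.0]
    stretched_mask_multiplier = value_map(mask_multiplier, mask_min_value, 1.0, 0.0, 1.0)
    # weight given to prior_pred in the masked blend; no gradients flow through it
    prior_mask_multiplier = 1.0 - stretched_mask_multiplier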
@@ -940,6 +940,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             batch.mask_tensor = double_up_tensor(batch.mask_tensor)
             batch.control_tensor = double_up_tensor(batch.control_tensor)
 
+        noisy_latent_multiplier = self.train_config.noisy_latent_multiplier
+
+        if noisy_latent_multiplier != 1.0:
+            noisy_latents = noisy_latents * noisy_latent_multiplier
+
         # remove grads for these
         noisy_latents.requires_grad = False
         noisy_latents = noisy_latents.detach()
 
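The new noisy_latent_multiplier simply rescales the noisy latents before they are detached; double_up_tensor is assumed here to repeat a tensor along the batch dimension. A rough sketch under those assumptions:

import torch

def double_up_tensor(t):
    # assumed behavior: duplicate the batch so paired samples share one forward pass
    return None if t is None else torch.cat([t, t], dim=0)

noisy_latents = torch.randn(2, 4, 64, 64)  # hypothetical latent batch
noisy_latent_multiplier = 1.5              # value would come from train_config

if noisy_latent_multiplier != 1.0:
    noisy_latents = noisy_latents * noisy_latent_multiplier

# treat the scaled latents as constants for backprop
noisy_latents.requires_grad = False
noisy_latents = noisy_latents.detach()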
@@ -1,7 +1,7 @@
 import gc
 import os
 from collections import OrderedDict
-from typing import ForwardRef, List
+from typing import ForwardRef, List, Optional, Union
 
 import torch
 from safetensors.torch import save_file, load_file
@@ -22,6 +22,7 @@ class GenerateConfig:
         self.sampler = kwargs.get('sampler', 'ddpm')
         self.width = kwargs.get('width', 512)
         self.height = kwargs.get('height', 512)
+        self.size_list: Union[List[int], None] = kwargs.get('size_list', None)
         self.neg = kwargs.get('neg', '')
         self.seed = kwargs.get('seed', -1)
         self.guidance_scale = kwargs.get('guidance_scale', 7)
@@ -30,6 +31,7 @@ class GenerateConfig:
         self.neg_2 = kwargs.get('neg_2', None)
         self.prompts = kwargs.get('prompts', None)
         self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
+        self.compile = kwargs.get('compile', False)
         self.ext = kwargs.get('ext', 'png')
         self.prompt_file = kwargs.get('prompt_file', False)
         self.prompts_in_file = self.prompts
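Both new GenerateConfig fields are plain kwargs with defaults, so older job configs keep working. A hypothetical config payload that would exercise them (key names from the diff, values invented):

generate_kwargs = {
    'prompts': ['a photo of a dog', 'a photo of a cat'],
    'width': 512,
    'height': 512,
    # entries are unpacked later as (width, height) pairs
    'size_list': [[512, 512], [512, 768], [768, 512]],
    'compile': True,  # opt in to torch.compile of the unet
}

size_list = generate_kwargs.get('size_list', None)
compile_unet = generate_kwargs.get('compile', False)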
@@ -93,17 +95,26 @@ class GenerateProcess(BaseProcess):
         self.sd.load_model()
 
         print("Compiling model...")
-        self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
+        # self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead", fullgraph=True)
+        if self.generate_config.compile:
+            self.sd.unet = torch.compile(self.sd.unet, mode="reduce-overhead")
 
         print(f"Generating {len(self.generate_config.prompts)} images")
         # build prompt image configs
         prompt_image_configs = []
         for prompt in self.generate_config.prompts:
+            width = self.generate_config.width
+            height = self.generate_config.height
+
+            if self.generate_config.size_list is not None:
+                # randomly select a size
+                width, height = random.choice(self.generate_config.size_list)
+
             prompt_image_configs.append(GenerateImageConfig(
                 prompt=prompt,
                 prompt_2=self.generate_config.prompt_2,
-                width=self.generate_config.width,
-                height=self.generate_config.height,
+                width=width,
+                height=height,
                 num_inference_steps=self.generate_config.sample_steps,
                 guidance_scale=self.generate_config.guidance_scale,
                 negative_prompt=self.generate_config.neg,
 
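The compile path is now opt-in and drops fullgraph=True, which would raise on any graph break inside the UNet forward. A minimal sketch of the guarded call on a stand-in module (the flag and module below are placeholders):

import torch

unet = torch.nn.Conv2d(4, 4, 3, padding=1)  # stand-in for sd.unet
compile_enabled = True                      # mirrors generate_config.compile

if compile_enabled:
    # reduce-overhead trades compile time for lower per-call overhead (CUDA graphs where possible)
    unet = torch.compile(unet, mode="reduce-overhead")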
@@ -21,12 +21,14 @@ from toolkit.data_loader import AiToolkitDataset, get_dataloader_from_datasets,
     trigger_dataloader_setup_epoch
 from toolkit.config_modules import DatasetConfig
 import argparse
+from tqdm import tqdm
 
 parser = argparse.ArgumentParser()
 parser.add_argument('dataset_folder', type=str, default='input')
 parser.add_argument('--epochs', type=int, default=1)
+
 
 
 args = parser.parse_args()
 
 dataset_folder = args.dataset_folder
@@ -40,27 +42,27 @@ batch_size = 1
 dataset_config = DatasetConfig(
     dataset_path=dataset_folder,
     resolution=resolution,
-    caption_ext='json',
+    # caption_ext='json',
     default_caption='default',
-    clip_image_path='/mnt/Datasets/face_pairs2/control_clean',
+    # clip_image_path='/mnt/Datasets2/regs/yetibear_xl_v14/random_aspect/',
     buckets=True,
     bucket_tolerance=bucket_tolerance,
-    poi='person',
-    augmentations=[
-        {
-            'method': 'RandomBrightnessContrast',
-            'brightness_limit': (-0.3, 0.3),
-            'contrast_limit': (-0.3, 0.3),
-            'brightness_by_max': False,
-            'p': 1.0
-        },
-        {
-            'method': 'HueSaturationValue',
-            'hue_shift_limit': (-0, 0),
-            'sat_shift_limit': (-40, 40),
-            'val_shift_limit': (-40, 40),
-            'p': 1.0
-        },
+    # poi='person',
+    # augmentations=[
+    #     {
+    #         'method': 'RandomBrightnessContrast',
+    #         'brightness_limit': (-0.3, 0.3),
+    #         'contrast_limit': (-0.3, 0.3),
+    #         'brightness_by_max': False,
+    #         'p': 1.0
+    #     },
+    #     {
+    #         'method': 'HueSaturationValue',
+    #         'hue_shift_limit': (-0, 0),
+    #         'sat_shift_limit': (-40, 40),
+    #         'val_shift_limit': (-40, 40),
+    #         'p': 1.0
+    #     },
     # {
     #     'method': 'RGBShift',
     #     'r_shift_limit': (-20, 20),
@@ -68,7 +70,7 @@ dataset_config = DatasetConfig(
     #     'b_shift_limit': (-20, 20),
     #     'p': 1.0
     # },
-    ]
+    # ]
 
 
 )
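The augmentation entries in this test config are plain dicts whose 'method' values match albumentations transform names; a hedged sketch of how such a list could map onto an albumentations pipeline (the mapping below is an assumption for illustration, not the toolkit's loader code):

import albumentations as A  # pip install albumentations

aug_configs = [
    {
        'method': 'RandomBrightnessContrast',
        'brightness_limit': (-0.3, 0.3),
        'contrast_limit': (-0.3, 0.3),
        'brightness_by_max': False,
        'p': 1.0,
    },
    {
        'method': 'HueSaturationValue',
        'hue_shift_limit': (0, 0),
        'sat_shift_limit': (-40, 40),
        'val_shift_limit': (-40, 40),
        'p': 1.0,
    },
]

# look each transform class up by name and pass the remaining keys as kwargs
transforms = [
    getattr(A, cfg['method'])(**{k: v for k, v in cfg.items() if k != 'method'})
    for cfg in aug_configs
]
pipeline = A.Compose(transforms)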
@@ -79,7 +81,7 @@ dataloader: DataLoader = get_dataloader_from_datasets([dataset_config], batch_si
 # run through an epoch and check sizes
 dataloader_iterator = iter(dataloader)
 for epoch in range(args.epochs):
-    for batch in dataloader:
+    for batch in tqdm(dataloader):
         batch: 'DataLoaderBatchDTO'
         img_batch = batch.tensor
 
@@ -98,7 +100,7 @@ for epoch in range(args.epochs):
 
         show_img(img)
 
-        time.sleep(1.0)
+        # time.sleep(0.1)
     # if not last epoch
     if epoch < args.epochs - 1:
         trigger_dataloader_setup_epoch(dataloader)
 
@@ -41,20 +41,37 @@ class Embedder(nn.Module):
         self.layer_norm = nn.LayerNorm(input_dim)
         self.fc1 = nn.Linear(input_dim, mid_dim)
         self.gelu = nn.GELU()
-        self.fc2 = nn.Linear(mid_dim, output_dim * num_output_tokens)
-        # self.fc2 = nn.Linear(mid_dim, mid_dim)
+        self.fc2 = nn.Linear(mid_dim, mid_dim)
 
+        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))
+        self.fc2.weight.data.zero_()
+
+        self.layer_norm2 = nn.LayerNorm(mid_dim)
+        self.fc3 = nn.Linear(mid_dim, mid_dim)
+        self.gelu2 = nn.GELU()
+        self.fc4 = nn.Linear(mid_dim, output_dim * num_output_tokens)
+
+        # set the weights to 0
+        self.fc3.weight.data.zero_()
+        self.fc4.weight.data.zero_()
 
 
-        # self.static_tokens = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
-        # self.scaler = nn.Parameter(torch.zeros(num_output_tokens, output_dim))
 
     def forward(self, x):
         if len(x.shape) == 2:
             x = x.unsqueeze(1)
         x = self.layer_norm(x)
         x = self.fc1(x)
         x = self.gelu(x)
         x = self.fc2(x)
-        x = x.view(-1, self.num_output_tokens, self.output_dim)
+        x = self.layer_norm2(x)
+        x = self.fc3(x)
+        x = self.gelu2(x)
+        x = self.fc4(x)
+
+        # repeat the static tokens for each batch
+        static_tokens = torch.stack([self.static_tokens] * x.shape[0])
+        x = static_tokens + x
+        x = x.view(-1, self.num_output_tokens, self.output_dim)
 
         return x
 
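The reworked Embedder deepens the MLP and zero-initializes the weights of fc2, fc3 and fc4, so at the start of training the image-conditioned branch is (almost) silent and the output is carried by the learned static tokens. A reduced sketch of that initialization pattern with made-up dimensions:

import torch
import torch.nn as nn

class TinyEmbedder(nn.Module):
    def __init__(self, input_dim=768, mid_dim=1024, output_dim=2048, num_output_tokens=4):
        super().__init__()
        self.num_output_tokens = num_output_tokens
        self.output_dim = output_dim
        self.body = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, mid_dim),
            nn.GELU(),
            nn.Linear(mid_dim, output_dim * num_output_tokens),
        )
        # zero-init the final projection: the conditioned branch starts at zero
        self.body[3].weight.data.zero_()
        self.body[3].bias.data.zero_()
        # learned tokens carry the output until the branch warms up
        self.static_tokens = nn.Parameter(torch.randn(num_output_tokens, output_dim))

    def forward(self, x):
        if x.dim() == 2:
            x = x.unsqueeze(1)
        out = self.body(x).view(-1, self.num_output_tokens, self.output_dim)
        return self.static_tokens.unsqueeze(0) + out

emb = TinyEmbedder()
tokens = emb(torch.randn(2, 768))  # shape (2, 4, 2048); equals static_tokens at init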
@@ -89,6 +106,7 @@ class ClipVisionAdapter(torch.nn.Module):
         print(f"Adding {placeholder_tokens} tokens to tokenizer")
+        print(f"Adding {self.config.num_tokens} tokens to tokenizer")
 
 
         for text_encoder, tokenizer in zip(self.text_encoder_list, self.tokenizer_list):
             num_added_tokens = tokenizer.add_tokens(placeholder_tokens)
             if num_added_tokens != self.config.num_tokens:
@@ -246,6 +246,7 @@ class TrainConfig:
         self.adapter_assist_name_or_path: Optional[str] = kwargs.get('adapter_assist_name_or_path', None)
         self.noise_multiplier = kwargs.get('noise_multiplier', 1.0)
         self.img_multiplier = kwargs.get('img_multiplier', 1.0)
+        self.noisy_latent_multiplier = kwargs.get('noisy_latent_multiplier', 1.0)
         self.latent_multiplier = kwargs.get('latent_multiplier', 1.0)
         self.negative_prompt = kwargs.get('negative_prompt', None)
         self.max_negative_prompts = kwargs.get('max_negative_prompts', 1)
 
@@ -86,18 +86,19 @@ class Embedding:
         self.orig_embeds_params = [x.get_input_embeddings().weight.data.clone() for x in self.text_encoder_list]
 
     def restore_embeddings(self):
-        # Let's make sure we don't update any embedding weights besides the newly added token
-        for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
-                                                                               self.tokenizer_list,
-                                                                               self.orig_embeds_params,
-                                                                               self.placeholder_token_ids):
-            index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
-            index_no_updates[
-                min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
-            with torch.no_grad():
+        with torch.no_grad():
+            # Let's make sure we don't update any embedding weights besides the newly added token
+            for text_encoder, tokenizer, orig_embeds, placeholder_token_ids in zip(self.text_encoder_list,
+                                                                                   self.tokenizer_list,
+                                                                                   self.orig_embeds_params,
+                                                                                   self.placeholder_token_ids):
+                index_no_updates = torch.ones((len(tokenizer),), dtype=torch.bool)
+                index_no_updates[ min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
                 text_encoder.get_input_embeddings().weight[
                     index_no_updates
                 ] = orig_embeds[index_no_updates]
                 weight = text_encoder.get_input_embeddings().weight
                 pass
 
     def get_trainable_params(self):
         params = []
 
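restore_embeddings now runs the whole loop under one torch.no_grad() block; the pattern itself is the usual textual-inversion trick of copying the original embedding rows back for every token except the newly added placeholders. A self-contained sketch with a toy embedding table (sizes and token ids are made up):

import torch

vocab_size, dim = 10, 8
embedding = torch.nn.Embedding(vocab_size, dim)
orig_embeds = embedding.weight.data.clone()  # snapshot taken before training
placeholder_token_ids = [7, 8]               # hypothetical newly added tokens

# ... optimizer steps may have drifted every row of embedding.weight ...

with torch.no_grad():
    index_no_updates = torch.ones((vocab_size,), dtype=torch.bool)
    index_no_updates[min(placeholder_token_ids): max(placeholder_token_ids) + 1] = False
    # restore every row except the placeholder tokens
    embedding.weight[index_no_updates] = orig_embeds[index_no_updates]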
@@ -387,7 +387,7 @@ class IPAdapter(torch.nn.Module):
             cross_attn_dim = 4096 if is_pixart else sd.unet.config['cross_attention_dim']
             image_proj_model = MLPProjModelClipFace(
                 cross_attention_dim=cross_attn_dim,
-                id_embeddings_dim=1024,
+                id_embeddings_dim=self.image_encoder.config.projection_dim,
                 num_tokens=self.config.num_tokens,  # usually 4
             )
         elif adapter_config.type == 'ip+':
@@ -486,7 +486,21 @@ class IPAdapter(torch.nn.Module):
 
         attn_processor_names = []
 
+        blocks = []
+        transformer_blocks = []
         for name in attn_processor_keys:
+            name_split = name.split(".")
+            block_name = f"{name_split[0]}.{name_split[1]}"
+            transformer_idx = name_split.index("transformer_blocks") if "transformer_blocks" in name_split else -1
+            if transformer_idx >= 0:
+                transformer_name = ".".join(name_split[:2])
+                transformer_name += "." + ".".join(name_split[transformer_idx:transformer_idx + 2])
+                if transformer_name not in transformer_blocks:
+                    transformer_blocks.append(transformer_name)
+
+
+            if block_name not in blocks:
+                blocks.append(block_name)
             cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") or name.endswith("attn1") else \
                 sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
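The new bookkeeping collects unique top-level block names and transformer block names from the attention processor keys. A standalone sketch on a few hypothetical keys in the usual diffusers layout:

attn_processor_keys = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor",
    "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor",
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor",
]

blocks = []
transformer_blocks = []
for name in attn_processor_keys:
    name_split = name.split(".")
    block_name = f"{name_split[0]}.{name_split[1]}"  # e.g. "down_blocks.1"
    if "transformer_blocks" in name_split:
        idx = name_split.index("transformer_blocks")
        transformer_name = ".".join(name_split[:2]) + "." + ".".join(name_split[idx:idx + 2])
        if transformer_name not in transformer_blocks:
            transformer_blocks.append(transformer_name)
    if block_name not in blocks:
        blocks.append(block_name)

print(blocks)              # ['down_blocks.1', 'mid_block.attentions']
print(transformer_blocks)  # ['down_blocks.1.transformer_blocks.0', 'mid_block.attentions.transformer_blocks.0']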
@@ -15,6 +15,30 @@ if TYPE_CHECKING:
     from toolkit.stable_diffusion_model import StableDiffusion
 
 
+class ILoRAProjModule(torch.nn.Module):
+    def __init__(self, num_modules=1, dim=4, embeddings_dim=512):
+        super().__init__()
+
+        self.num_modules = num_modules
+        self.num_dim = dim
+        self.norm = torch.nn.LayerNorm(embeddings_dim)
+
+        self.proj = torch.nn.Sequential(
+            torch.nn.Linear(embeddings_dim, embeddings_dim * 2),
+            torch.nn.GELU(),
+            torch.nn.Linear(embeddings_dim * 2, num_modules * dim),
+        )
+        # Initialize the last linear layer weights near zero
+        torch.nn.init.uniform_(self.proj[2].weight, a=-0.01, b=0.01)
+        torch.nn.init.zeros_(self.proj[2].bias)
+
+    def forward(self, x):
+        x = self.norm(x)
+        x = self.proj(x)
+        x = x.reshape(-1, self.num_modules, self.num_dim)
+        return x
+
+
 class InstantLoRAMidModule(torch.nn.Module):
     def __init__(
             self,
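ILoRAProjModule maps one vision embedding to one small vector per LoRA module: LayerNorm, a 2x-wide GELU MLP, and a final layer initialized near zero so the adapter starts close to inactive. Its shape contract, sketched on random data (assumes the class from the hunk above is in scope):

import torch

proj = ILoRAProjModule(num_modules=16, dim=4, embeddings_dim=512)

clip_embeds = torch.randn(2, 512)  # hypothetical pooled vision embeddings
out = proj(clip_embeds)
print(out.shape)                   # torch.Size([2, 16, 4]): one 4-d scaler vector per module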
@@ -54,7 +78,7 @@ class InstantLoRAMidModule(torch.nn.Module):
             raise e
         # apply tanh to limit values to -1 to 1
         # scaler = torch.tanh(scaler)
-        return x * (scaler + 1.0)
+        return x * scaler
 
 
 class InstantLoRAModule(torch.nn.Module):
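The one-line change above swaps a residual-style gate, x * (scaler + 1.0), which passes the LoRA output through unchanged when the projection emits zero, for a plain multiplicative gate, x * scaler, which suppresses it at a near-zero init. A tiny numerical illustration:

import torch

x = torch.ones(3)
scaler = torch.zeros(3)       # what a zero-initialized projection would emit

print(x * (scaler + 1.0))     # tensor([1., 1., 1.]) -> identity at init
print(x * scaler)             # tensor([0., 0., 0.]) -> silent at init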
@@ -92,20 +116,25 @@ class InstantLoRAModule(torch.nn.Module):
         #     num_blocks=1,
         # )
         # heads = 20
-        heads = 12
-        dim = 1280
-        output_dim = self.dim
-        self.resampler = Resampler(
-            dim=dim,
-            depth=4,
-            dim_head=64,
-            heads=heads,
-            num_queries=len(lora_modules),
-            embedding_dim=self.vision_hidden_size,
-            max_seq_len=self.vision_tokens,
-            output_dim=output_dim,
-            ff_mult=4
-        )
+        # heads = 12
+        # dim = 1280
+        # output_dim = self.dim
+        self.proj_module = ILoRAProjModule(
+            num_modules=len(lora_modules),
+            dim=self.dim,
+            embeddings_dim=self.vision_hidden_size,
+        )
+        # self.resampler = Resampler(
+        #     dim=dim,
+        #     depth=4,
+        #     dim_head=64,
+        #     heads=heads,
+        #     num_queries=len(lora_modules),
+        #     embedding_dim=self.vision_hidden_size,
+        #     max_seq_len=self.vision_tokens,
+        #     output_dim=output_dim,
+        #     ff_mult=4
+        # )
 
         for idx, lora_module in enumerate(lora_modules):
             # add a new mid module that will take the original forward and add a vector to it
@@ -128,6 +157,6 @@ class InstantLoRAModule(torch.nn.Module):
         # expand token rank if only rank 2
         if len(img_embeds.shape) == 2:
             img_embeds = img_embeds.unsqueeze(1)
-        img_embeds = self.resampler(img_embeds)
+        img_embeds = self.proj_module(img_embeds)
         self.img_embeds = img_embeds
 
@@ -390,7 +390,7 @@ def sample_images(
 
 # https://www.crosslabs.org//blog/diffusion-with-offset-noise
 def apply_noise_offset(noise, noise_offset):
-    if noise_offset is None or noise_offset < 0.0000001:
+    if noise_offset is None or (noise_offset < 0.000001 and noise_offset > -0.000001):
         return noise
     noise = noise + noise_offset * torch.randn((noise.shape[0], noise.shape[1], 1, 1), device=noise.device)
     return noise
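The updated guard only short-circuits when the offset is within ±1e-6 of zero, so a negative noise offset is now applied instead of being swallowed by the old one-sided check. An equivalent sketch written with abs() (same condition, just compacted):

import torch

def apply_noise_offset(noise, noise_offset):
    # skip when there is no meaningful offset, positive or negative
    if noise_offset is None or abs(noise_offset) < 1e-6:
        return noise
    # one offset sample per (batch, channel), broadcast over the spatial dims
    return noise + noise_offset * torch.randn(
        (noise.shape[0], noise.shape[1], 1, 1), device=noise.device
    )

noise = torch.randn(2, 4, 64, 64)
print(apply_noise_offset(noise, -0.05).shape)              # torch.Size([2, 4, 64, 64])
print(torch.equal(apply_noise_offset(noise, 0.0), noise))  # True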