mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-01-26 16:39:47 +00:00
Added specialized scaler training to ip adapters
@@ -1243,8 +1243,10 @@ class SDTrainer(BaseSDTrainProcess):
                        has_been_preprocessed=True,
                        quad_count=quad_count
                    )
                # else:
                #     raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
                else:
                    print("No Clip Image")
                    print([file_item.path for file_item in batch.file_items])
                    raise ValueError("Could not find clip image")

            if not self.adapter_config.train_image_encoder:
                # we are not training the image encoder, so we need to detach the embeds
@@ -1293,11 +1293,18 @@ class BaseSDTrainProcess(BaseTrainProcess):
        if self.adapter_config is not None:
            self.setup_adapter()
            if self.adapter_config.train:
                # set trainable params
                params.append({
                    'params': self.adapter.parameters(),
                    'lr': self.train_config.adapter_lr
                })

                if isinstance(self.adapter, IPAdapter):
                    # we have custom LR groups for IPAdapter
                    adapter_param_groups = self.adapter.get_parameter_groups(self.train_config.adapter_lr)
                    for group in adapter_param_groups:
                        params.append(group)
                else:
                    # set trainable params
                    params.append({
                        'params': self.adapter.parameters(),
                        'lr': self.train_config.adapter_lr
                    })

            if self.train_config.gradient_checkpointing:
                self.adapter.enable_gradient_checkpointing()
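Illustrative sketch (not from the repo): how per-group learning rates like the IPAdapter groups appended above are consumed by a standard PyTorch optimizer. Module names here are hypothetical stand-ins.

import torch

# Hypothetical modules standing in for adapter sub-parts.
proj = torch.nn.Linear(16, 16)
scaler = torch.nn.Parameter(torch.ones(1))

param_groups = [
    {"params": proj.parameters(), "lr": 1e-4},   # regular adapter weights
    {"params": [scaler], "lr": 1e-3},            # scaler can use its own lr
]

# torch.optim accepts a list of dicts; each group keeps its own lr.
optimizer = torch.optim.AdamW(param_groups)
for group in optimizer.param_groups:
    print(group["lr"])   # 0.0001 then 0.001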
@@ -43,9 +43,7 @@ def preprocess_config(config: OrderedDict, name: str = None):
    if "name" not in config["config"] and name is None:
        raise ValueError("config file must have a config.name key")
    # we need to replace tags. For now just [name]
    if name is not None:
        config["config"]["name"] = name
    else:
    if name is None:
        name = config["config"]["name"]
    config_string = json.dumps(config)
    config_string = config_string.replace("[name]", name)
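Illustrative sketch of the [name] tag round-trip this hunk performs: dump the config to JSON, substitute the tag, and load it back. The config contents are made up for the example.

import json
from collections import OrderedDict

# Illustrative config; real files carry many more keys.
config = OrderedDict({"config": {"name": "my_run", "output_dir": "output/[name]"}})
name = config["config"]["name"]

config_string = json.dumps(config)
config_string = config_string.replace("[name]", name)
config = json.loads(config_string, object_pairs_hook=OrderedDict)
print(config["config"]["output_dir"])  # output/my_run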
@@ -181,6 +181,12 @@ class AdapterConfig:
        self.text_encoder_path: str = kwargs.get('text_encoder_path', None)
        self.text_encoder_arch: str = kwargs.get('text_encoder_arch', 'clip')  # clip or t5

        self.train_scaler: bool = kwargs.get('train_scaler', False)
        self.scaler_lr: Optional[float] = kwargs.get('scaler_lr', None)

        # trains with a scaler to ease channel bias but merges it in on save
        self.merge_scaler: bool = kwargs.get('merge_scaler', False)


class EmbeddingConfig:
    def __init__(self, **kwargs):
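Illustrative sketch of how these kwargs-driven fields behave, using only the field names and defaults visible in the hunk above (the surrounding class is simplified and hypothetical):

from typing import Optional

class AdapterConfigSketch:
    # Simplified stand-in for AdapterConfig, keeping only the new scaler fields.
    def __init__(self, **kwargs):
        self.train_scaler: bool = kwargs.get('train_scaler', False)
        self.scaler_lr: Optional[float] = kwargs.get('scaler_lr', None)
        # trains with a scaler to ease channel bias but merges it in on save
        self.merge_scaler: bool = kwargs.get('merge_scaler', False)

cfg = AdapterConfigSketch(train_scaler=True, scaler_lr=1e-3)
print(cfg.train_scaler, cfg.scaler_lr, cfg.merge_scaler)  # True 0.001 False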
@@ -80,9 +80,14 @@ class MLPProjModelClipFace(torch.nn.Module):


class CustomIPAttentionProcessor(IPAttnProcessor2_0):
    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None):
    def __init__(self, hidden_size, cross_attention_dim, scale=1.0, num_tokens=4, adapter=None, train_scaler=False):
        super().__init__(hidden_size, cross_attention_dim, scale=scale, num_tokens=num_tokens)
        self.adapter_ref: weakref.ref = weakref.ref(adapter)
        self.train_scaler = train_scaler
        if train_scaler:
            # self.ip_scaler = torch.nn.Parameter(torch.ones([num_tokens], dtype=torch.float32) * 0.9999)
            self.ip_scaler = torch.nn.Parameter(torch.ones([1], dtype=torch.float32) * 0.9999)
            self.ip_scaler.requires_grad_(True)

    def __call__(
            self,
@@ -169,6 +174,13 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):

        # will be none if disabled
        if ip_hidden_states is not None:
            # apply scaler
            if self.train_scaler:
                weight = self.ip_scaler
                # reshape to (1, self.num_tokens, 1)
                weight = weight.view(1, -1, 1)
                ip_hidden_states = ip_hidden_states * weight

            # for ip-adapter
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)
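Illustrative sketch (dummy shapes, not the real attention processor) of the broadcasted multiply above: a single learnable scale, reshaped to (1, -1, 1), scales every token and channel of the IP hidden states.

import torch

batch, tokens, dim = 2, 4, 8
ip_hidden_states = torch.randn(batch, tokens, dim)

# Single learnable scale, initialized just below 1.0 as in the hunk above.
ip_scaler = torch.nn.Parameter(torch.ones([1]) * 0.9999)

weight = ip_scaler.view(1, -1, 1)   # shape (1, 1, 1), broadcasts over batch, tokens, and dim
scaled = ip_hidden_states * weight
print(scaled.shape)                 # torch.Size([2, 4, 8])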
@@ -185,7 +197,8 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        ip_hidden_states = ip_hidden_states.to(query.dtype)

        hidden_states = hidden_states + self.scale * ip_hidden_states
        scale = self.scale
        hidden_states = hidden_states + scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
@@ -202,6 +215,21 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):

        return hidden_states

    # this ensures that the ip_scaler is not changed when we load the model
    # def _apply(self, fn):
    #     if hasattr(self, "ip_scaler"):
    #         # Overriding the _apply method to prevent the special_parameter from changing dtype
    #         self.ip_scaler = fn(self.ip_scaler)
    #         # Temporarily set the special_parameter to None to exclude it from default _apply processing
    #         ip_scaler = self.ip_scaler
    #         self.ip_scaler = None
    #         super(CustomIPAttentionProcessor, self)._apply(fn)
    #         # Restore the special_parameter after the default _apply processing
    #         self.ip_scaler = ip_scaler
    #         return self
    #     else:
    #         return super(CustomIPAttentionProcessor, self)._apply(fn)


# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
@@ -485,7 +513,8 @@ class IPAdapter(torch.nn.Module):
                cross_attention_dim=cross_attention_dim,
                scale=1.0,
                num_tokens=self.config.num_tokens,
                adapter=self
                adapter=self,
                train_scaler=self.config.train_scaler or self.config.merge_scaler
            )
            if self.sd_ref().is_pixart:
                # pixart is much more sensitive
@@ -494,7 +523,7 @@ class IPAdapter(torch.nn.Module):
                    "to_v_ip.weight": weights["to_v_ip.weight"] * 0.01,
                }

            attn_procs[name].load_state_dict(weights)
            attn_procs[name].load_state_dict(weights, strict=False)
            attn_processor_names.append(name)
        print(f"Attn Processors")
        print(attn_processor_names)
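Illustrative sketch of why strict=False is needed once the processors gain an extra ip_scaler entry: PyTorch's load_state_dict otherwise raises on keys missing from the checkpoint. The Proc class below is a hypothetical stand-in for the attention processor.

import torch

class Proc(torch.nn.Module):
    def __init__(self, train_scaler=False):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(8, 8, bias=False)
        if train_scaler:
            self.ip_scaler = torch.nn.Parameter(torch.ones(1))

weights = Proc(train_scaler=False).state_dict()   # checkpoint saved without a scaler
proc = Proc(train_scaler=True)
missing, unexpected = proc.load_state_dict(weights, strict=False)
print(missing)   # ['ip_scaler'] stays at its initialized value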
@@ -568,9 +597,34 @@ class IPAdapter(torch.nn.Module):
        state_dict = OrderedDict()
        if self.config.train_only_image_encoder:
            return self.image_encoder.state_dict()
        if self.config.train_scaler:
            state_dict["ip_scale"] = self.adapter_modules.state_dict()
            # remove items that are not scalers
            for key in list(state_dict["ip_scale"].keys()):
                if not key.endswith("ip_scaler"):
                    del state_dict["ip_scale"][key]
            return state_dict

        state_dict["image_proj"] = self.image_proj_model.state_dict()
        state_dict["ip_adapter"] = self.adapter_modules.state_dict()
        # handle merge scaler training
        if self.config.merge_scaler:
            for key in list(state_dict["ip_adapter"].keys()):
                if key.endswith("ip_scaler"):
                    # merge in the scaler so we don't have to save it and it will be compatible with other ip adapters
                    scale = state_dict["ip_adapter"][key].clone()

                    key_start = key.split(".")[-2]
                    # reshape to (1, 1)
                    scale = scale.view(1, 1)
                    del state_dict["ip_adapter"][key]
                    # find the to_k_ip and to_v_ip keys
                    for key2 in list(state_dict["ip_adapter"].keys()):
                        if key2.endswith(f"{key_start}.to_k_ip.weight"):
                            state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale
                        if key2.endswith(f"{key_start}.to_v_ip.weight"):
                            state_dict["ip_adapter"][key2] = state_dict["ip_adapter"][key2].clone() * scale

        if self.config.train_image_encoder:
            state_dict["image_encoder"] = self.image_encoder.state_dict()
        if self.preprocessor is not None:
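Illustrative sketch of why merging the scaler into to_k_ip/to_v_ip at save time is safe: scaling the weight matrix is equivalent to scaling the input, since (s W) x = s (W x). The layers and values below are dummies.

import torch

torch.manual_seed(0)
s = torch.tensor(0.85)
to_k_ip = torch.nn.Linear(8, 8, bias=False)
x = torch.randn(2, 8)

scaled_input = to_k_ip(x * s)                    # what the runtime scaler does
merged = torch.nn.Linear(8, 8, bias=False)
merged.weight.data = to_k_ip.weight.data * s     # what save-time merging does
folded = merged(x)

print(torch.allclose(scaled_input, folded, atol=1e-6))  # True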
@@ -866,18 +920,61 @@ class IPAdapter(torch.nn.Module):
            self.image_proj_model.train(mode)
        return super().train(mode)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
    def get_parameter_groups(self, adapter_lr):
        param_groups = []
        # when training just the scaler, we do not train anything else
        if not self.config.train_scaler:
            param_groups.append({
                "params": self.get_non_scaler_parameters(),
                "lr": adapter_lr,
            })
        if self.config.train_scaler or self.config.merge_scaler:
            scaler_lr = adapter_lr if self.config.scaler_lr is None else self.config.scaler_lr
            param_groups.append({
                "params": self.get_scaler_parameters(),
                "lr": scaler_lr,
            })
        return param_groups

    def get_scaler_parameters(self):
        # only get the scalers from the adapter modules
        for attn_processor in self.adapter_modules:
            # only get the scaler
            # check if it has ip_scaler attribute
            if hasattr(attn_processor, "ip_scaler"):
                scaler_param = attn_processor.ip_scaler
                yield scaler_param

    def get_non_scaler_parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        if self.config.train_only_image_encoder:
            yield from self.image_encoder.parameters(recurse)
            return
        if self.config.train_scaler:
            # no params
            return

        for attn_processor in self.adapter_modules:
            yield from attn_processor.parameters(recurse)
            if self.config.train_scaler or self.config.merge_scaler:
                # todo remove scaler
                if hasattr(attn_processor, "to_k_ip"):
                    # yield the linear layer
                    yield from attn_processor.to_k_ip.parameters(recurse)
                if hasattr(attn_processor, "to_v_ip"):
                    # yield the linear layer
                    yield from attn_processor.to_v_ip.parameters(recurse)
            else:
                yield from attn_processor.parameters(recurse)
        yield from self.image_proj_model.parameters(recurse)
        if self.config.train_image_encoder:
            yield from self.image_encoder.parameters(recurse)
        if self.preprocessor is not None:
            yield from self.preprocessor.parameters(recurse)

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        yield from self.get_non_scaler_parameters(recurse)
        if self.config.train_scaler or self.config.merge_scaler:
            yield from self.get_scaler_parameters()

    def merge_in_weights(self, state_dict: Mapping[str, Any]):
        # merge in img_proj weights
        current_img_proj_state_dict = self.image_proj_model.state_dict()
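Illustrative sketch of the hasattr-based split used above to separate scaler parameters from everything else; the module layout is hypothetical and much smaller than the real adapter.

import torch

class ProcWithScaler(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_k_ip = torch.nn.Linear(8, 8, bias=False)
        self.to_v_ip = torch.nn.Linear(8, 8, bias=False)
        self.ip_scaler = torch.nn.Parameter(torch.ones(1))

modules = torch.nn.ModuleList([ProcWithScaler(), torch.nn.Linear(8, 8)])

def scaler_params(mods):
    # yield only the per-processor scale parameters
    for m in mods:
        if hasattr(m, "ip_scaler"):
            yield m.ip_scaler

def non_scaler_params(mods):
    # yield everything except the scale parameters
    for m in mods:
        for name, p in m.named_parameters():
            if not name.endswith("ip_scaler"):
                yield p

print(sum(p.numel() for p in scaler_params(modules)))      # 1
print(sum(p.numel() for p in non_scaler_params(modules)))  # 200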
@@ -975,6 +1072,8 @@ class IPAdapter(torch.nn.Module):

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        strict = False
        if self.config.train_scaler and 'ip_scale' in state_dict:
            self.adapter_modules.load_state_dict(state_dict["ip_scale"], strict=False)
        if 'ip_adapter' in state_dict:
            try:
                self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
@@ -5,6 +5,10 @@ import torch.nn as nn
from typing import TYPE_CHECKING
from toolkit.models.clip_fusion import ZipperBlock
from toolkit.models.zipper_resampler import ZipperModule, ZipperResampler
import sys
from toolkit.paths import REPOS_ROOT
sys.path.append(REPOS_ROOT)
from ipadapter.ip_adapter.resampler import Resampler

if TYPE_CHECKING:
    from toolkit.lora_special import LoRAModule
@@ -50,7 +54,7 @@ class InstantLoRAMidModule(torch.nn.Module):
            raise e
        # apply tanh to limit values to -1 to 1
        # scaler = torch.tanh(scaler)
        return x * scaler
        return x * (scaler + 1.0)


class InstantLoRAModule(torch.nn.Module):
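A plausible reading of the (scaler + 1.0) change, shown on dummy tensors: if the predicted scaler starts near zero, the old form zeros out the signal while the new form behaves roughly as identity. This is an illustration under that assumption, not code from the commit.

import torch

x = torch.randn(2, 16)
scaler = torch.zeros(2, 16)          # stand-in for a freshly initialized projection output

out_old = x * scaler                 # old form: kills the signal when scaler ~ 0
out_new = x * (scaler + 1.0)         # new form: near-identity when scaler ~ 0

print(out_old.abs().max().item())    # 0.0
print(torch.allclose(out_new, x))    # True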
@@ -78,15 +82,30 @@ class InstantLoRAModule(torch.nn.Module):
        lora_modules = self.sd_ref().network.get_all_modules()

        # resample the output so each module gets one token with a size of its dim so we can multiply by that
        self.resampler = ZipperResampler(
            in_size=self.vision_hidden_size,
            in_tokens=self.vision_tokens,
            out_size=self.dim,
            out_tokens=len(lora_modules),
            hidden_size=self.vision_hidden_size,
            hidden_tokens=self.vision_tokens,
            num_blocks=1,
        )
        # self.resampler = ZipperResampler(
        #     in_size=self.vision_hidden_size,
        #     in_tokens=self.vision_tokens,
        #     out_size=self.dim,
        #     out_tokens=len(lora_modules),
        #     hidden_size=self.vision_hidden_size,
        #     hidden_tokens=self.vision_tokens,
        #     num_blocks=1,
        # )
        # heads = 20
        heads = 12
        dim = 1280
        output_dim = self.dim
        self.resampler = Resampler(
            dim=dim,
            depth=4,
            dim_head=64,
            heads=heads,
            num_queries=len(lora_modules),
            embedding_dim=self.vision_hidden_size,
            max_seq_len=self.vision_tokens,
            output_dim=output_dim,
            ff_mult=4
        )

        for idx, lora_module in enumerate(lora_modules):
            # add a new mid module that will take the original forward and add a vector to it
@@ -39,7 +39,7 @@ from toolkit.pipelines import CustomStableDiffusionXLPipeline, CustomStableDiffu
    StableDiffusionKDiffusionXLPipeline, StableDiffusionXLRefinerPipeline
from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, T2IAdapter, DDPMScheduler, \
    StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline, DiffusionPipeline, \
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel
    StableDiffusionXLImg2ImgPipeline, LCMScheduler, Transformer2DModel, AutoencoderTiny
import diffusers
from diffusers import \
    AutoencoderKL, \
@@ -872,8 +872,10 @@ class StableDiffusion:
            is_input_scaled=False,
            detach_unconditional=False,
            rescale_cfg=None,
            return_conditional_pred=False,
            **kwargs,
    ):
        conditional_pred = None
        # get the embeddings
        if text_embeddings is None and conditional_embeddings is None:
            raise ValueError("Either text_embeddings or conditional_embeddings must be specified")
@@ -1024,9 +1026,12 @@ class StableDiffusion:
                **kwargs,
            ).sample

            conditional_pred = noise_pred

            if do_classifier_free_guidance:
                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                conditional_pred = noise_pred_text
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
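For reference, the chunk-and-combine lines above are standard classifier-free guidance. A minimal sketch on dummy tensors (batch layout assumed: unconditional half first, conditional half second):

import torch

guidance_scale = 7.5
# Batched prediction: first half unconditional, second half conditional (text).
noise_pred = torch.randn(4, 4, 8, 8)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
conditional_pred = noise_pred_text            # what return_conditional_pred hands back
guided = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(guided.shape)                           # torch.Size([2, 4, 8, 8])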
@@ -1112,9 +1117,12 @@ class StableDiffusion:
                **kwargs,
            ).sample

            conditional_pred = noise_pred

            if do_classifier_free_guidance:
                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2, dim=0)
                conditional_pred = noise_pred_text
                if detach_unconditional:
                    noise_pred_uncond = noise_pred_uncond.detach()
                noise_pred = noise_pred_uncond + guidance_scale * (
@@ -1141,6 +1149,8 @@ class StableDiffusion:
            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

        if return_conditional_pred:
            return noise_pred, conditional_pred
        return noise_pred

    def step_scheduler(self, model_input, latent_input, timestep_tensor, noise_scheduler=None):
@@ -1187,23 +1197,30 @@ class StableDiffusion:
            bleed_ratio: float = 0.5,
            bleed_latents: torch.FloatTensor = None,
            is_input_scaled=False,
            return_first_prediction=False,
            **kwargs,
    ):
        timesteps_to_run = self.noise_scheduler.timesteps[start_timesteps:total_timesteps]

        first_prediction = None

        for timestep in tqdm(timesteps_to_run, leave=False):
            timestep = timestep.unsqueeze_(0)
            noise_pred = self.predict_noise(
            noise_pred, conditional_pred = self.predict_noise(
                latents,
                text_embeddings,
                timestep,
                guidance_scale=guidance_scale,
                add_time_ids=add_time_ids,
                is_input_scaled=is_input_scaled,
                return_conditional_pred=True,
                **kwargs,
            )
            # some schedulers need to run separately, so do that. (euler for example)

            if return_first_prediction and first_prediction is None:
                first_prediction = conditional_pred

            latents = self.step_scheduler(noise_pred, latents, timestep)

            # if not last step, and bleeding, bleed in some latents
@@ -1214,6 +1231,8 @@ class StableDiffusion:
                is_input_scaled = False

        # return latents_steps
        if return_first_prediction:
            return latents, first_prediction
        return latents

    def encode_prompt(
@@ -1311,7 +1330,10 @@ class StableDiffusion:
            image_list[i] = Resize((image.shape[1] // 8 * 8, image.shape[2] // 8 * 8))(image)

        images = torch.stack(image_list)
        latents = self.vae.encode(images).latent_dist.sample()
        if isinstance(self.vae, AutoencoderTiny):
            latents = self.vae.encode(images, return_dict=False)[0]
        else:
            latents = self.vae.encode(images).latent_dist.sample()
            # latents = self.vae.encode(images, return_dict=False)[0]
        latents = latents * self.vae.config['scaling_factor']
        latents = latents.to(device, dtype=dtype)
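The branch above reflects that AutoencoderTiny's encoder returns latents directly, while AutoencoderKL returns a latent distribution to sample from. A standalone sketch of that difference follows; the checkpoint id is only an example, the call downloads weights, and the random input is a dummy used just to show shapes.

import torch
from diffusers import AutoencoderKL, AutoencoderTiny

# Example checkpoint only; any compatible tiny VAE would do.
vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float32)
images = torch.randn(1, 3, 512, 512)  # dummy batch standing in for preprocessed images

if isinstance(vae, AutoencoderTiny):
    # Tiny VAE: encode returns latents directly, there is no latent_dist to sample.
    latents = vae.encode(images, return_dict=False)[0]
else:
    latents = vae.encode(images).latent_dist.sample()

latents = latents * vae.config['scaling_factor']
print(latents.shape)  # (1, 4, 64, 64) for a 512x512 input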