Added IP adapter training. Not functioning correctly yet

Jaret Burkett
2023-09-24 02:39:43 -06:00
parent 19255cdc7c
commit 830e87cb87
9 changed files with 336 additions and 53 deletions

.gitmodules

@@ -7,3 +7,6 @@
[submodule "repositories/batch_annotator"]
path = repositories/batch_annotator
url = https://github.com/ostris/batch-annotator
[submodule "repositories/ipadapter"]
path = repositories/ipadapter
url = https://github.com/tencent-ailab/IP-Adapter.git


@@ -2,9 +2,11 @@ import os.path
from collections import OrderedDict
from PIL import Image
from diffusers import T2IAdapter
from torch.utils.data import DataLoader
from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO
from toolkit.ip_adapter import IPAdapter
from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
from toolkit.train_tools import get_torch_dtype, apply_snr_weight
@@ -115,13 +117,19 @@ class SDTrainer(BaseSDTrainProcess):
conditional_embeds = conditional_embeds.detach()
# flush()
pred_kwargs = {}
if self.adapter:
if self.adapter and isinstance(self.adapter, T2IAdapter):
down_block_additional_residuals = self.adapter(adapter_images)
down_block_additional_residuals = [
sample.to(dtype=dtype) for sample in down_block_additional_residuals
]
pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals
if self.adapter and isinstance(self.adapter, IPAdapter):
with torch.no_grad():
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
noise_pred = self.sd.predict_noise(
latents=noisy_latents.to(self.device_torch, dtype=dtype),
conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
@@ -164,6 +172,9 @@ class SDTrainer(BaseSDTrainProcess):
loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
loss = loss.mean()
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")
# IMPORTANT: with gradient checkpointing, do not exit the `with network` block before calling backward()
# it will destroy the gradients, because the network is a context manager
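For orientation, the IP-Adapter branch above conditions the unet by appending projected image tokens to the prompt embeddings, rather than by injecting residuals the way the T2IAdapter branch does. A minimal self-contained sketch of that idea in plain PyTorch (the shapes and the single-linear projection are illustrative assumptions, not the actual toolkit classes):

import torch

# toy shapes: batch of 2, 77 text tokens, 768-dim cross-attention features
text_embeds = torch.randn(2, 77, 768)        # from the text encoder
clip_image_embeds = torch.randn(2, 1024)     # pooled CLIP image embedding (illustrative size)

# an ImageProjModel-style projection: one linear layer reshaped into N extra context tokens
num_tokens = 4
proj = torch.nn.Linear(1024, 768 * num_tokens)
image_tokens = proj(clip_image_embeds).reshape(2, num_tokens, 768)

# the adapter forward simply concatenates image tokens after the text tokens,
# so cross-attention now attends over 77 + 4 tokens
conditional_embeds = torch.cat([text_embeds, image_tokens], dim=1)
print(conditional_embeds.shape)  # torch.Size([2, 81, 768])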


@@ -15,6 +15,7 @@ from toolkit.basic import value_map
from toolkit.data_loader import get_dataloader_from_datasets
from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
from toolkit.embedding import Embedding
from toolkit.ip_adapter import IPAdapter
from toolkit.lora_special import LoRASpecialNetwork
from toolkit.lycoris_special import LycorisSpecialNetwork
from toolkit.network_mixins import Network
@@ -22,7 +23,8 @@ from toolkit.optimizer import get_optimizer
from toolkit.paths import CONFIG_ROOT
from toolkit.progress_bar import ToolkitProgressBar
from toolkit.sampler import get_sampler
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model
from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \
load_ip_adapter_model
from toolkit.scheduler import get_lr_scheduler
from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
@@ -118,9 +120,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
# to hold network if there is one
self.network: Union[Network, None] = None
self.adapter: Union[T2IAdapter, None] = None
self.adapter: Union[T2IAdapter, IPAdapter, None] = None
self.embedding: Union[Embedding, None] = None
is_training_adapter = self.adapter_config is not None and self.adapter_config.train
# get the device state preset based on what we are training
self.train_device_state_preset = get_train_sd_device_state_preset(
device=self.device_torch,
@@ -128,17 +132,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
train_text_encoder=self.train_config.train_text_encoder,
cached_latents=self.is_latents_cached,
train_lora=self.network_config is not None,
train_adapter=self.adapter_config is not None,
train_adapter=is_training_adapter,
train_embedding=self.embed_config is not None,
)
# fine_tuning here means training the actual SD network itself (Dreambooth, etc.), not LoRA, embeddings, etc.
self.is_fine_tuning = True
if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None:
if self.network_config is not None or is_training_adapter or self.embed_config is not None:
self.is_fine_tuning = False
self.named_lora = False
if self.embed_config is not None or self.adapter_config is not None:
if self.embed_config is not None or is_training_adapter:
self.named_lora = True
def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]):
# override in subclass
@@ -179,7 +183,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
)
extra_args = {}
if self.adapter_config is not None:
if self.adapter_config is not None and self.adapter_config.test_img_path is not None:
extra_args['adapter_image_path'] = self.adapter_config.test_img_path
gen_img_config_list.append(GenerateImageConfig(
@@ -318,22 +322,33 @@ class BaseSDTrainProcess(BaseTrainProcess):
emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt"
self.embedding.save(emb_file_path)
if self.adapter is not None:
if self.adapter is not None and self.adapter_config.train:
adapter_name = self.job.name
if self.network_config is not None or self.embedding is not None:
# add _lora to name
adapter_name += '_t2i'
if self.adapter_config.type == 't2i':
adapter_name += '_t2i'
else:
adapter_name += '_ip'
filename = f'{adapter_name}{step_num}.safetensors'
file_path = os.path.join(self.save_root, filename)
# save adapter
state_dict = self.adapter.state_dict()
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
if self.adapter_config.type == 't2i':
save_t2i_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
save_ip_adapter_from_diffusers(
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
)
else:
self.sd.save(
file_path,
@@ -527,6 +542,50 @@ class BaseSDTrainProcess(BaseTrainProcess):
return noisy_latents, noise, timesteps, conditioned_prompts, imgs
def setup_adapter(self):
dtype = get_torch_dtype(self.train_config.dtype)
is_t2i = self.adapter_config.type == 't2i'
if is_t2i:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
else:
self.adapter = IPAdapter(
sd=self.sd,
adapter_config=self.adapter_config,
)
self.adapter.to(self.device_torch, dtype=dtype)
# adapter name suffix used for save/load paths
suffix = 't2i' if is_t2i else 'ip'
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_{suffix}"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
if is_t2i:
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device,
dtype=dtype
)
else:
loaded_state_dict = load_ip_adapter_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
if self.adapter_config.train:
self.load_training_state_from_metadata(latest_save_path)
# set trainable params
self.sd.adapter = self.adapter
def run(self):
# torch.autograd.set_detect_anomaly(True)
# run base process run
@@ -741,35 +800,12 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
if self.adapter_config is not None:
self.adapter = T2IAdapter(
in_channels=self.adapter_config.in_channels,
channels=self.adapter_config.channels,
num_res_blocks=self.adapter_config.num_res_blocks,
downscale_factor=self.adapter_config.downscale_factor,
adapter_type=self.adapter_config.adapter_type,
)
self.adapter.to(self.device_torch, dtype=dtype)
# t2i adapter
adapter_name = self.name
if self.network_config is not None:
adapter_name = f"{adapter_name}_t2i"
latest_save_path = self.get_latest_save_path(adapter_name)
if latest_save_path is not None:
# load adapter from path
print(f"Loading adapter from {latest_save_path}")
loaded_state_dict = load_t2i_model(
latest_save_path,
self.device,
dtype=dtype
)
self.adapter.load_state_dict(loaded_state_dict)
self.load_training_state_from_metadata(latest_save_path)
self.setup_adapter()
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
self.sd.adapter = self.adapter
flush()
params = self.load_additional_training_modules(params)
@@ -785,6 +821,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
unet_lr=self.train_config.lr,
default_lr=self.train_config.lr
)
# we may be using it for prompt injections
if self.adapter_config is not None:
self.setup_adapter()
flush()
### HOOK ###
params = self.hook_add_extra_train_params(params)
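The adapter gets its own optimizer parameter group so it can use train_config.adapter_lr independently of the unet/text-encoder learning rate. The same pattern in plain PyTorch (toy modules and made-up rates, not the toolkit's optimizer factory):

import torch

unet = torch.nn.Linear(4, 4)
adapter = torch.nn.Linear(4, 4)

params = [
    {'params': unet.parameters(), 'lr': 1e-5},      # default_lr / unet lr
    {'params': adapter.parameters(), 'lr': 1e-4},   # train_config.adapter_lr
]
optimizer = torch.optim.AdamW(params)
print([group['lr'] for group in optimizer.param_groups])  # [1e-05, 0.0001]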


@@ -61,8 +61,11 @@ class NetworkConfig:
self.dropout: Union[float, None] = kwargs.get('dropout', None)
AdapterTypes = Literal['t2i', 'ip', 'ip+']
class AdapterConfig:
def __init__(self, **kwargs):
self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip
self.in_channels: int = kwargs.get('in_channels', 3)
self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280])
self.num_res_blocks: int = kwargs.get('num_res_blocks', 2)
@@ -70,6 +73,9 @@ class AdapterConfig:
self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter')
self.image_dir: str = kwargs.get('image_dir', None)
self.test_img_path: str = kwargs.get('test_img_path', None)
self.train: str = kwargs.get('train', False)
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
self.name_or_path = kwargs.get('name_or_path', None)
class EmbeddingConfig:
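As a rough illustration of the new AdapterConfig fields, a job could request IP-Adapter training with something like the following (the paths and the CLIP checkpoint name are made-up example values; 'type' accepts 't2i', 'ip', or 'ip+'):

from toolkit.config_modules import AdapterConfig

adapter_config = AdapterConfig(
    type='ip',                                            # or 'ip+' for the resampler variant
    train=True,                                           # train the adapter rather than only loading it
    image_encoder_path='openai/clip-vit-large-patch14',   # example CLIP vision encoder
    image_dir='/path/to/conditioning/images',
    test_img_path='/path/to/test_image.png',
)
print(adapter_config.type, adapter_config.train)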

toolkit/ip_adapter.py (new file)

@@ -0,0 +1,157 @@
import torch
import sys
from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref
if TYPE_CHECKING:
from toolkit.stable_diffusion_model import StableDiffusion
# loosely based on https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
        super().__init__()
        self.config = adapter_config
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.clip_image_processor = CLIPImageProcessor()
        self.device = self.sd_ref().unet.device
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
        if adapter_config.type == 'ip':
            # ip-adapter: project the pooled CLIP embedding into a fixed number of extra context tokens
            image_proj_model = ImageProjModel(
                cross_attention_dim=sd.unet.config['cross_attention_dim'],
                clip_embeddings_dim=self.image_encoder.config.projection_dim,
                clip_extra_context_tokens=4,
            )
        elif adapter_config.type == 'ip+':
            # ip-adapter-plus: resample the CLIP hidden states into image tokens
            num_tokens = 16
            image_proj_model = Resampler(
                dim=sd.unet.config['cross_attention_dim'],
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=num_tokens,
                embedding_dim=self.image_encoder.config.hidden_size,
                output_dim=sd.unet.config['cross_attention_dim'],
                ff_mult=4
            )
        else:
            raise ValueError(f"unknown adapter type: {adapter_config.type}")

        # init adapter modules: swap an IPAttnProcessor into every cross-attention layer of the unet
        attn_procs = {}
        unet_sd = sd.unet.state_dict()
        for name in sd.unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = sd.unet.config['block_out_channels'][block_id]
            else:
                # the reference implementation did not handle this case, which would leave hidden_size undefined below
                raise ValueError(f"unknown attn processor name: {name}")
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                # initialize the new image key/value projections from the existing text to_k/to_v weights
                layer_name = name.split(".processor")[0]
                weights = {
                    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                }
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
                attn_procs[name].load_state_dict(weights)
        sd.unet.set_attn_processor(attn_procs)
        adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

        sd.adapter = self
        self.unet_ref: weakref.ref = weakref.ref(sd.unet)
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.image_encoder.to(*args, **kwargs)
        self.image_proj_model.to(*args, **kwargs)
        self.adapter_modules.to(*args, **kwargs)
        return self

    def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self.unet_ref().attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    def state_dict(self) -> OrderedDict:
        state_dict = OrderedDict()
        state_dict["image_proj"] = self.image_proj_model.state_dict()
        state_dict["ip_adapter"] = self.adapter_modules.state_dict()
        return state_dict

    def set_scale(self, scale):
        for attn_processor in self.unet_ref().attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    @torch.no_grad()
    def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor:
        # todo: add support for sdxl
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        if drop:
            clip_image = clip_image * 0
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        return clip_image_embeds

    @torch.no_grad()
    def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor:
        # tensors should be 0-1
        # todo: add support for sdxl
        if tensors_0_1.ndim == 3:
            tensors_0_1 = tensors_0_1.unsqueeze(0)
        tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
        clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        if drop:
            clip_image = clip_image * 0
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        return clip_image_embeds

    # use drop for prompt dropout, or negatives
    def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
        # project the CLIP image embeddings and append them to the text embeddings as extra context tokens
        clip_image_embeds = clip_image_embeds.detach()
        clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        image_prompt_embeds = self.image_proj_model(clip_image_embeds.detach())
        embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
        return embeddings

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        # only the attention processors and the image projection are trainable; the CLIP image encoder stays frozen
        for attn_processor in self.adapter_modules:
            yield from attn_processor.parameters(recurse)
        yield from self.image_proj_model.parameters(recurse)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
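Note that parameters() is overridden so the optimizer only ever sees the swapped-in attention processors and the image projection; the CLIP image encoder stays frozen. A small self-contained sketch of the same pattern with toy modules (not the toolkit classes):

import torch

class ToyAdapter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.image_encoder = torch.nn.Linear(8, 8)    # frozen, deliberately excluded below
        self.image_proj_model = torch.nn.Linear(8, 8)
        self.adapter_modules = torch.nn.ModuleList([torch.nn.Linear(8, 8)])

    def parameters(self, recurse: bool = True):
        # expose only the trainable pieces to the optimizer
        for module in self.adapter_modules:
            yield from module.parameters(recurse)
        yield from self.image_proj_model.parameters(recurse)

adapter = ToyAdapter()
optimizer = torch.optim.AdamW(adapter.parameters(), lr=1e-4)
print(sum(p.numel() for p in adapter.parameters()))  # excludes the image_encoder weights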


@@ -306,6 +306,7 @@ class ToolkitNetworkMixin:
extra_dict = None
return extra_dict
@torch.no_grad()
def _update_torch_multiplier(self: Network):
# builds a tensor for fast usage in the forward pass of the network modules
# without having to set it in every single module every time it changes


@@ -193,3 +193,42 @@ def load_t2i_model(
# todo see if we need to convert dict
converted_state_dict[key] = value.detach().to(device, dtype=dtype)
return converted_state_dict
IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter']
def save_ip_adapter_from_diffusers(
combined_state_dict: 'OrderedDict',
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
):
# todo: test compatibility with non diffusers
converted_state_dict = OrderedDict()
for module_name, state_dict in combined_state_dict.items():
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
save_file(converted_state_dict, output_file, metadata=meta)
def load_ip_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)
if module_name not in combined_state_dict:
combined_state_dict[module_name] = OrderedDict()
combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype)
return combined_state_dict
else:
return torch.load(path_to_file, map_location=device)
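The two helpers above are intended to be inverses: saving flattens the nested {'image_proj': ..., 'ip_adapter': ...} state dict into 'module_name.key' entries for safetensors, and loading splits them back apart on the first dot. A self-contained sketch of just that key handling (plain values standing in for tensors):

from collections import OrderedDict

combined = {
    "image_proj": {"proj.weight": 1.0},
    "ip_adapter": {"0.to_k_ip.weight": 2.0},
}

# save path: flatten to "module_name.key" entries (what ends up in the safetensors file)
flat = OrderedDict(
    (f"{module_name}.{key}", value)
    for module_name, sub_dict in combined.items()
    for key, value in sub_dict.items()
)

# load path: split the leading module name back off each key
restored = OrderedDict()
for combo_key, value in flat.items():
    module_name, _, key = combo_key.partition('.')
    restored.setdefault(module_name, OrderedDict())[key] = value

assert restored == combined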


@@ -13,8 +13,9 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import
from safetensors.torch import save_file, load_file
from torch.nn import Parameter
from tqdm import tqdm
from torchvision.transforms import Resize
from torchvision.transforms import Resize, transforms
from toolkit.ip_adapter import IPAdapter
from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \
convert_vae_state_dict, load_vae
from toolkit import train_tools
@@ -123,7 +124,7 @@ class StableDiffusion:
# to hold network if there is one
self.network = None
self.adapter: Union['T2IAdapter', None] = None
self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None
self.is_xl = model_config.is_xl
self.is_v2 = model_config.is_v2
@@ -302,12 +303,13 @@ class StableDiffusion:
Pipe = StableDiffusionPipeline
extra_args = {}
if self.adapter:
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
if self.adapter is not None:
if isinstance(self.adapter, T2IAdapter):
if self.is_xl:
Pipe = StableDiffusionXLAdapterPipeline
else:
Pipe = StableDiffusionAdapterPipeline
extra_args['adapter'] = self.adapter
# TODO add clip skip
if self.is_xl:
@@ -358,11 +360,19 @@ class StableDiffusion:
gen_config = image_configs[i]
extra = {}
if gen_config.adapter_image_path is not None:
if self.adapter is not None and gen_config.adapter_image_path is not None:
validation_image = Image.open(gen_config.adapter_image_path).convert("RGB")
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0
if isinstance(self.adapter, T2IAdapter):
# not sure why this is double??
validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2))
extra['image'] = validation_image
extra['adapter_conditioning_scale'] = 1.0
if isinstance(self.adapter, IPAdapter):
transform = transforms.Compose([
transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.PILToTensor(),
])
validation_image = transform(validation_image)
if self.network is not None:
self.network.multiplier = gen_config.network_multiplier
@@ -382,6 +392,14 @@ class StableDiffusion:
unconditional_embeds,
)
if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None:
# apply the image projection
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image)
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True)
conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds)
unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds)
# todo do we disable text encoder here as well if disabled for model, or only do that for training?
if self.is_xl:
# fix guidance rescale for sdxl
@@ -931,10 +949,18 @@ class StableDiffusion:
'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad
}
if self.adapter is not None:
if isinstance(self.adapter, IPAdapter):
requires_grad = self.adapter.adapter_modules.training
adapter_device = self.unet.device
elif isinstance(self.adapter, T2IAdapter):
requires_grad = self.adapter.adapter.conv_in.weight.requires_grad
adapter_device = self.adapter.device
else:
raise ValueError(f"Unknown adapter type: {type(self.adapter)}")
self.device_state['adapter'] = {
'training': self.adapter.training,
'device': self.adapter.device,
'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad,
'device': adapter_device,
'requires_grad': requires_grad,
}
def restore_device_state(self):
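One detail worth calling out in the generation path above: the adapter is applied to both prompt embeddings, with the unconditional branch passing drop=True so the image is zeroed out and classifier-free guidance gets a matching "empty image" negative. A toy sketch of that pairing (plain tensors and a stand-in encoder, not the toolkit objects):

import torch

def get_image_embeds(image: torch.Tensor, drop: bool = False) -> torch.Tensor:
    # stand-in for the CLIP encoder call; drop=True zeroes the input, mirroring the dropout/negative path
    if drop:
        image = image * 0
    return image.mean(dim=(-1, -2))  # pretend pooled embedding

image = torch.rand(1, 3, 224, 224)
conditional_clip_embeds = get_image_embeds(image)
unconditional_clip_embeds = get_image_embeds(image, drop=True)
print(conditional_clip_embeds.abs().sum() > 0)     # tensor(True)
print(unconditional_clip_embeds.abs().sum() == 0)  # tensor(True)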