Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Many bug fixes, including IP adapter fixes. Added noise to the unconditional image embeds, which works better. Added an iLoRA adapter for one-shotting LoRAs.
@@ -286,9 +286,10 @@ class SDTrainer(BaseSDTrainProcess):
)
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
if torch.isnan(prior_loss).any():
raise ValueError("Prior loss is nan")

prior_loss = prior_loss.mean([1, 2, 3])
print("Prior loss is nan")
prior_loss = None
else:
prior_loss = prior_loss.mean([1, 2, 3])
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if prior_loss is not None:
@@ -992,6 +993,8 @@ class SDTrainer(BaseSDTrainProcess):

if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
# number of images to do if doing a quad image
quad_count = random.randint(1, 4)
image_size = self.adapter.input_size
if is_reg:
# we will zero it out in the img embedder
@@ -1004,7 +1007,8 @@ class SDTrainer(BaseSDTrainProcess):
clip_images,
drop=True,
is_training=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
@@ -1014,13 +1018,15 @@ class SDTrainer(BaseSDTrainProcess):
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
elif has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
@@ -1030,7 +1036,8 @@ class SDTrainer(BaseSDTrainProcess):
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
else:
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
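In plain terms, the trainer now builds two image-embedding batches per step when CFG is enabled: one from the real reference images and one "dropped" unconditional pass, and the randomly chosen quad_count is threaded through both calls. A rough, self-contained sketch of that calling pattern (the embed function below is a toy stand-in for the adapter method, not the toolkit's code):

```python
import random
import torch

def get_clip_image_embeds(images: torch.Tensor, drop: bool = False, quad_count: int = 4) -> torch.Tensor:
    # toy stand-in for IPAdapter.get_clip_image_embeds_from_tensors: a dropped
    # (unconditional) pass ignores the real pixels entirely
    if drop:
        images = torch.rand_like(images)  # noised unconditional, per this commit
    return images.mean(dim=(2, 3))        # toy "embedding"

clip_images = torch.rand(2, 3, 224, 224)   # batch of reference images in [0, 1]
quad_count = random.randint(1, 4)          # how many grid tiles to use this step

conditional = get_clip_image_embeds(clip_images, drop=False, quad_count=quad_count)
unconditional = get_clip_image_embeds(clip_images, drop=True, quad_count=quad_count)
# conditional / unconditional are later paired for classifier-free-guidance-style training
print(conditional.shape, unconditional.shape)
```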
@@ -1152,7 +1159,8 @@ class SDTrainer(BaseSDTrainProcess):
)
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")
print("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)

with self.timer('backward'):
# todo we have multiplier seperated. works for now as res are not in same batch, but need to change

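The NaN handling above swaps a hard failure for a recoverable step: a NaN loss is replaced by a zero tensor that still requires grad, so backward() runs and the step is effectively skipped instead of crashing the run. A minimal illustration of why the replacement tensor needs requires_grad_ (toy model, not the trainer's code):

```python
import torch

model = torch.nn.Linear(4, 1)
x = torch.rand(8, 4)
loss = model(x).mean() * float("nan")  # simulate a bad step

if torch.isnan(loss):
    print("loss is nan")
    # zeros_like keeps device/dtype; requires_grad_ lets backward() run on the
    # detached replacement, so the parameters simply receive no gradient update
    loss = torch.zeros_like(loss).requires_grad_(True)

loss.backward()  # no-op step: training continues on the next batch
```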
@@ -2,6 +2,7 @@ import copy
import glob
import inspect
import json
import random
import shutil
from collections import OrderedDict
import os
@@ -423,7 +424,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
adapter_name += '_t2i'
elif self.adapter_config.type == 'clip':
adapter_name += '_clip'
elif self.adapter_config.type == 'ip':
elif self.adapter_config.type.startswith('ip'):
adapter_name += '_ip'
else:
adapter_name += '_adapter'
@@ -444,7 +445,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
dtype=get_torch_dtype(self.save_config.dtype),
direct_save=self.adapter_config.train_only_image_encoder
)
else:
if self.save_config.save_format == "diffusers":
@@ -1010,7 +1012,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
loaded_state_dict = load_ip_adapter_model(
latest_save_path,
self.device,
dtype=dtype
dtype=dtype,
direct_load=self.adapter_config.train_only_image_encoder
)
self.adapter.load_state_dict(loaded_state_dict)
else:
@@ -1146,14 +1149,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)

# load the adapters before the dataset as they may use the clip encoders
if self.adapter_config is not None:
self.setup_adapter()
flush()
if not self.is_fine_tuning:
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
network_kwargs = {}
network_kwargs = self.network_config.network_kwargs
is_lycoris = False
is_lorm = self.network_config.type.lower() == 'lorm'
# default to LoCON if there are any conv layers or if it is named
@@ -1279,12 +1279,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()

if self.adapter_config is not None:
# self.setup_adapter()
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
self.setup_adapter()
if self.adapter_config.train:
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})

if self.train_config.gradient_checkpointing:
self.adapter.enable_gradient_checkpointing()
flush()

params = self.load_additional_training_modules(params)
@@ -1306,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
refiner_lr=self.train_config.refiner_lr,
)
# we may be using it for prompt injections
if self.adapter_config is not None:
if self.adapter_config is not None and self.adapter is None:
self.setup_adapter()
flush()
### HOOK ###
@@ -1379,7 +1383,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# sample first
if self.train_config.skip_first_sample:
self.print("Skipping first sample due to config setting")
elif self.step_num <= 1:
elif self.step_num <= 1 or self.train_config.force_first_sample:
self.print("Generating baseline samples before training")
self.sample(self.step_num)

@@ -1422,6 +1426,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
start_step_num = self.step_num
did_first_flush = False
for step in range(start_step_num, self.train_config.steps):
if self.train_config.do_random_cfg:
self.train_config.do_cfg = True
self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
self.step_num = step
# default to true so various things can turn it off
self.is_grad_accumulation_step = True

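When do_random_cfg is on, each training step draws a fresh CFG scale between 1.0 and max_cfg_scale. The toolkit's value_map helper is not shown in this diff; assuming it is a plain linear remap, the per-step draw works roughly like this:

```python
import random

def value_map(value: float, min_in: float, max_in: float, min_out: float, max_out: float) -> float:
    # assumed linear remap from [min_in, max_in] to [min_out, max_out]
    return min_out + (value - min_in) * (max_out - min_out) / (max_in - min_in)

max_cfg_scale = 4.0
for step in range(3):
    cfg_scale = value_map(random.random(), 0, 1, 1.0, max_cfg_scale)
    print(f"step {step}: cfg_scale={cfg_scale:.2f}")  # uniform in [1.0, 4.0]
```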
@@ -113,6 +113,7 @@ class NetworkConfig:
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
self.dropout: Union[float, None] = kwargs.get('dropout', None)
self.network_kwargs: dict = kwargs.get('network_kwargs', {})

self.lorm_config: Union[LoRMConfig, None] = None
lorm = kwargs.get('lorm', None)
@@ -153,10 +154,14 @@ class AdapterConfig:

self.num_tokens: int = num_tokens
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
if self.train_only_image_encoder:
self.train_image_encoder = True
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
self.safe_channels: int = kwargs.get('safe_channels', 2048)
self.safe_tokens: int = kwargs.get('safe_tokens', 8)
self.quad_image: bool = kwargs.get('quad_image', False)

# clip vision
self.trigger = kwargs.get('trigger', 'tri993r')
@@ -211,6 +216,7 @@ class TrainConfig:
self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.force_first_sample = kwargs.get('force_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
@@ -275,7 +281,9 @@ class TrainConfig:

self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
self.do_cfg = kwargs.get('do_cfg', False)
self.do_random_cfg = kwargs.get('do_random_cfg', False)
self.cfg_scale = kwargs.get('cfg_scale', 1.0)
self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)


class ModelConfig:

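Since these options are all read through kwargs.get, the new behavior is switched on from the training config. A hypothetical kwargs sketch covering the options this commit introduces (key names are the ones in the diff; the values are only illustrative):

```python
train_kwargs = {
    "force_first_sample": True,   # always generate baseline samples, even when resuming
    "do_random_cfg": True,        # draw a new cfg_scale each step
    "max_cfg_scale": 4.0,         # upper bound for the random draw
}
adapter_kwargs = {
    "train_only_image_encoder": True,  # also forces train_image_encoder = True
    "quad_image": True,                # feed a grid of reference images per sample
}
# these dicts would be passed as **kwargs into TrainConfig / AdapterConfig
```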
@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.ilora import InstantLoRAModule
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
from toolkit.saving import load_ip_adapter_model
@@ -74,6 +75,7 @@ class CustomAdapter(torch.nn.Module):
self.clip_image_processor = self.image_processor

self.clip_fusion_module: CLIPFusionModule = None
self.ilora_module: InstantLoRAModule = None

self.setup_adapter()

@@ -106,6 +108,15 @@ class CustomAdapter(torch.nn.Module):
vision_hidden_size=self.vision_encoder.config.hidden_size,
vision_tokens=vision_tokens
)
elif self.adapter_type == 'ilora':
vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
if self.config.image_encoder_arch == 'clip':
vision_tokens = vision_tokens + 1
self.ilora_module = InstantLoRAModule(
vision_tokens=vision_tokens,
vision_hidden_size=self.vision_encoder.config.hidden_size,
sd=self.sd_ref()
)
else:
raise ValueError(f"unknown adapter type: {self.adapter_type}")

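vision_tokens here is just the ViT patch-grid size, plus one CLS token when the encoder is a standard CLIP vision model. Worked out for the common CLIP ViT-L/14 at 224 px input:

```python
image_size, patch_size = 224, 14                   # CLIP ViT-L/14 defaults
vision_tokens = (image_size // patch_size) ** 2    # 16 * 16 = 256 patch tokens
vision_tokens += 1                                 # + CLS token for the 'clip' arch
print(vision_tokens)                               # 257 hidden states per image fed to InstantLoRAModule
```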
@@ -283,6 +294,9 @@ class CustomAdapter(torch.nn.Module):
if 'fuse_module' in state_dict:
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)

if 'ilora' in state_dict:
self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)

pass

def state_dict(self) -> OrderedDict:
@@ -301,6 +315,11 @@ class CustomAdapter(torch.nn.Module):
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
return state_dict
elif self.adapter_type == 'ilora':
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
state_dict["ilora"] = self.ilora_module.state_dict()
return state_dict
else:
raise NotImplementedError

@@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type == 'clip_fusion':
if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
return prompt
elif self.adapter_type == 'photo_maker':
if is_unconditional:
@@ -408,7 +427,7 @@ class CustomAdapter(torch.nn.Module):
has_been_preprocessed=False,
is_unconditional=False
) -> PromptEmbeds:
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
if is_unconditional:
# we dont condition the negative embeds for photo maker
return prompt_embeds.clone()
@@ -459,7 +478,7 @@ class CustomAdapter(torch.nn.Module):
self.token_mask
)
return prompt_embeds
elif self.adapter_type == 'clip_fusion':
elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
@@ -480,11 +499,17 @@ class CustomAdapter(torch.nn.Module):
if not is_training or not self.config.train_image_encoder:
img_embeds = img_embeds.detach()

prompt_embeds.text_embeds = self.clip_fusion_module(
prompt_embeds.text_embeds,
img_embeds
)
return prompt_embeds
if self.adapter_type == 'ilora':
self.ilora_module.img_embeds = img_embeds

return prompt_embeds
else:

prompt_embeds.text_embeds = self.clip_fusion_module(
prompt_embeds.text_embeds,
img_embeds
)
return prompt_embeds


else:
@@ -499,5 +524,9 @@ class CustomAdapter(torch.nn.Module):
yield from self.clip_fusion_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
elif self.config.type == 'ilora':
yield from self.ilora_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
else:
raise NotImplementedError

@@ -655,6 +655,21 @@ class ClipImageFileItemDTOMixin:
else:
self.clip_image_tensor = transforms.ToTensor()(img)

# random crop
# if self.dataset_config.clip_image_random_crop:
# # crop up to 20% on all sides. Keep is square
# crop_percent = random.randint(0, 20) / 100
# crop_width = int(self.clip_image_tensor.shape[2] * crop_percent)
# crop_height = int(self.clip_image_tensor.shape[1] * crop_percent)
# crop_left = random.randint(0, crop_width)
# crop_top = random.randint(0, crop_height)
# crop_right = self.clip_image_tensor.shape[2] - crop_width - crop_left
# crop_bottom = self.clip_image_tensor.shape[1] - crop_height - crop_top
# if len(self.clip_image_tensor.shape) == 3:
# self.clip_image_tensor = self.clip_image_tensor[:, crop_top:-crop_bottom, crop_left:-crop_right]
# elif len(self.clip_image_tensor.shape) == 4:
# self.clip_image_tensor = self.clip_image_tensor[:, :, crop_top:-crop_bottom, crop_left:-crop_right]

if self.clip_image_processor is not None:
# run it
tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16)

@@ -1,3 +1,5 @@
import random

import torch
import sys

@@ -39,6 +41,9 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio

from transformers import ViTFeatureExtractor, ViTForImageClassification

# gradient checkpointing
from torch.utils.checkpoint import checkpoint

import torch.nn.functional as F


@@ -166,6 +171,8 @@ class IPAdapter(torch.nn.Module):
self.device = self.sd_ref().unet.device
self.preprocessor: Optional[CLIPImagePreProcessor] = None
self.input_size = 224
self.clip_noise_zero = True
self.unconditional: torch.Tensor = None
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -236,6 +243,16 @@ class IPAdapter(torch.nn.Module):

self.input_size = self.image_encoder.config.image_size

if self.config.quad_image: # 4x4 image
# self.clip_image_processor.config
# We do a 3x downscale of the image, so we need to adjust the input size
preprocessor_input_size = self.image_encoder.config.image_size * 2

# update the preprocessor so images come in at the right size
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
self.clip_image_processor.crop_size['height'] = preprocessor_input_size
self.clip_image_processor.crop_size['width'] = preprocessor_input_size

if self.config.image_encoder_arch == 'clip+':
# self.clip_image_processor.config
# We do a 3x downscale of the image, so we need to adjust the input size
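The quad_image path tells the CLIP preprocessor to emit images at twice the encoder's native resolution, so each tile of the grid (split 2x2 into four tiles later on) ends up at the encoder's expected size. A minimal sketch with Hugging Face's CLIPImageProcessor (the checkpoint name is only an example):

```python
from transformers import CLIPImageProcessor

# example checkpoint; any CLIP vision processor with size/crop_size dicts works the same way
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
encoder_size = 224                      # image_encoder.config.image_size
grid_size = encoder_size * 2            # grid of reference images -> 448 px

processor.size["shortest_edge"] = grid_size
processor.crop_size["height"] = grid_size
processor.crop_size["width"] = grid_size
# the processor now returns 448x448 pixel values; halving them along height and
# width later yields four 224x224 tiles, one per reference image
```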
@@ -349,6 +366,15 @@ class IPAdapter(torch.nn.Module):
self.image_encoder.train()
self.image_encoder.requires_grad_(True)

# premake a unconditional
zerod = torch.zeros(1, 3, self.input_size, self.input_size, device=self.device, dtype=torch.float16)
self.unconditional = self.clip_image_processor(
images=zerod,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.image_encoder.to(*args, **kwargs)
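The adapter now caches a "blank image" input once at setup: a black image pushed through the same CLIP preprocessing, resized and normalized but not rescaled, since the tensor is already in 0-1. A standalone sketch of that idea (the processor checkpoint is illustrative):

```python
import torch
from transformers import CLIPImageProcessor

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
input_size = 224

# an all-black image stands in for "no reference image"
zerod = torch.zeros(1, 3, input_size, input_size)
unconditional = processor(
    images=zerod,
    return_tensors="pt",
    do_resize=True,
    do_rescale=False,   # values are already in 0-1, skip the /255 rescale
).pixel_values
print(unconditional.shape)  # (1, 3, 224, 224), normalized with the CLIP mean/std
```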
@@ -358,20 +384,23 @@ class IPAdapter(torch.nn.Module):
self.preprocessor.to(*args, **kwargs)
return self

def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"])
if self.preprocessor is not None and 'preprocessor' in state_dict:
self.preprocessor.load_state_dict(state_dict["preprocessor"])
# def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
# self.image_proj_model.load_state_dict(state_dict["image_proj"])
# ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
# ip_layers.load_state_dict(state_dict["ip_adapter"])
# if self.config.train_image_encoder and 'image_encoder' in state_dict:
# self.image_encoder.load_state_dict(state_dict["image_encoder"])
# if self.preprocessor is not None and 'preprocessor' in state_dict:
# self.preprocessor.load_state_dict(state_dict["preprocessor"])

# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
# self.load_ip_adapter(state_dict)

def state_dict(self) -> OrderedDict:
state_dict = OrderedDict()
if self.config.train_only_image_encoder:
return self.image_encoder.state_dict()

state_dict["image_proj"] = self.image_proj_model.state_dict()
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
if self.config.train_image_encoder:
@@ -402,13 +431,28 @@ class IPAdapter(torch.nn.Module):
# clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
# return clip_image_embeds

def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.image_encoder.to(*args, **kwargs)
self.image_proj_model.to(*args, **kwargs)
self.adapter_modules.to(*args, **kwargs)
if self.preprocessor is not None:
self.preprocessor.to(*args, **kwargs)
return self
def get_clip_image_embeds_from_tensors(
self,
tensors_0_1: torch.Tensor,
drop=False,
is_training=False,
has_been_preprocessed=False
has_been_preprocessed=False,
quad_count=4,
) -> torch.Tensor:
if self.sd_ref().unet.device != self.device:
self.to(self.sd_ref().unet.device)
if self.sd_ref().unet.device != self.image_encoder.device:
self.to(self.sd_ref().unet.device)
if not self.config.train:
is_training = False
with torch.no_grad():
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
@@ -417,11 +461,19 @@ class IPAdapter(torch.nn.Module):
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)

# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
# unconditional
if drop:
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
# tensors_0_1 = tensors_0_1 * 0
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
@@ -429,10 +481,42 @@ class IPAdapter(torch.nn.Module):
do_rescale=False,
).pixel_values
else:
clip_image = tensors_0_1
if drop:
# scale the noise down
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
# tensors_0_1 = tensors_0_1 * 0
mean = torch.tensor(self.clip_image_processor.image_mean).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
std = torch.tensor(self.clip_image_processor.image_std).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])

else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if drop:
clip_image = clip_image * 0

if self.config.quad_image:
# split the 4x4 grid and stack on batch
ci1, ci2 = clip_image.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
to_cat = []
for i, ci in enumerate([ci1, ci2, ci3, ci4]):
if i < quad_count:
to_cat.append(ci)
else:
break

clip_image = torch.cat(to_cat, dim=0).detach()

# if drop:
# clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
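The "added noise to unconditional" change from the commit message is the drop branch above: instead of a zeroed image, the dropped (unconditional) input becomes uniform random pixels, quantized to 8-bit levels and normalized with the CLIP mean/std by hand when the tensors were already preprocessed. A self-contained sketch of that replacement (the mean/std values are the standard OpenAI CLIP constants):

```python
import torch

clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073])
clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711])

def make_noised_unconditional(like: torch.Tensor) -> torch.Tensor:
    # uniform noise in [0, 1], snapped to 8-bit levels like a real image would be
    noise_0_1 = torch.rand_like(like)
    noise_0_1 = torch.clip(noise_0_1 * 255.0, 0, 255).round() / 255.0
    # normalize exactly as the CLIP preprocessor would, so the encoder sees
    # something statistically image-like rather than an all-zero input
    return (noise_0_1 - clip_mean.view(1, 3, 1, 1)) / clip_std.view(1, 3, 1, 1)

batch = torch.rand(2, 3, 224, 224)            # preprocessed-shape reference images
unconditional = make_noised_unconditional(batch)
print(unconditional.shape)                    # (2, 3, 224, 224)
```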
@@ -457,6 +541,20 @@ class IPAdapter(torch.nn.Module):
clip_image_embeds = clip_output.hidden_states[-2]
else:
clip_image_embeds = clip_output.image_embeds

if self.config.quad_image:
# get the outputs of the quat
chunks = clip_image_embeds.chunk(quad_count, dim=0)
chunk_sum = torch.zeros_like(chunks[0])
for chunk in chunks:
chunk_sum = chunk_sum + chunk
# get the mean of them

clip_image_embeds = chunk_sum / quad_count

if not is_training:
clip_image_embeds = clip_image_embeds.detach()

return clip_image_embeds

# use drop for prompt dropout, or negatives
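For quad images, the embeddings of the individual tiles are averaged back into one embedding per original grid, so downstream code sees the usual batch size regardless of quad_count. A toy sketch of the split-then-average round trip (the "encoder" is a stand-in, and the tile order here is illustrative):

```python
import torch

quad_count = 3
grid = torch.rand(1, 3, 448, 448)            # one 2x2 grid of 224px reference images

# split the grid into four tiles and stack the first quad_count on the batch dimension
top, bottom = grid.chunk(2, dim=2)
tl, tr = top.chunk(2, dim=3)
bl, br = bottom.chunk(2, dim=3)
tiles = torch.cat([tl, tr, bl, br][:quad_count], dim=0)   # (quad_count, 3, 224, 224)

tile_embeds = tiles.flatten(1).mean(dim=1, keepdim=True)  # toy stand-in for the CLIP encoder
# average the per-tile embeddings back into a single conditioning embedding
chunks = tile_embeds.chunk(quad_count, dim=0)
image_embed = sum(chunks) / quad_count
print(image_embed.shape)                                  # (1, 1): one embedding per grid
```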
@@ -467,6 +565,9 @@ class IPAdapter(torch.nn.Module):
return embeddings

def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.train_only_image_encoder:
yield from self.image_encoder.parameters(recurse)
return
for attn_processor in self.adapter_modules:
yield from attn_processor.parameters(recurse)
yield from self.image_proj_model.parameters(recurse)
@@ -561,17 +662,26 @@ class IPAdapter(torch.nn.Module):

def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
try:
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
except Exception as e:
print(e)
print("could not load ip adapter weights, trying to merge in weights")
self.merge_in_weights(state_dict)
if 'ip_adapter' in state_dict:
try:
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
except Exception as e:
print(e)
print("could not load ip adapter weights, trying to merge in weights")
self.merge_in_weights(state_dict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
if self.preprocessor is not None and 'preprocessor' in state_dict:
self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)

if self.config.train_only_image_encoder and 'ip_adapter' not in state_dict:
# we are loading pure clip weights.
self.image_encoder.load_state_dict(state_dict, strict=strict)


def enable_gradient_checkpointing(self):
self.image_encoder.gradient_checkpointing = True
if hasattr(self.image_encoder, "enable_gradient_checkpointing"):
self.image_encoder.enable_gradient_checkpointing()
elif hasattr(self.image_encoder, 'gradient_checkpointing'):
self.image_encoder.gradient_checkpointing = True

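With train_only_image_encoder, a saved checkpoint is just the raw image-encoder state dict, so loading has to dispatch on which keys are present: full IP adapter checkpoints carry an 'ip_adapter' entry, pure encoder checkpoints do not. A condensed sketch of that dispatch (the module arguments are placeholders, not the toolkit's API):

```python
import torch

def load_adapter_checkpoint(state_dict: dict, image_encoder: torch.nn.Module,
                            ip_modules: torch.nn.Module, train_only_image_encoder: bool):
    if "ip_adapter" in state_dict:
        # full checkpoint: image projection + attention processors, plus optional extras
        ip_modules.load_state_dict(state_dict["ip_adapter"], strict=False)
    if train_only_image_encoder and "ip_adapter" not in state_dict:
        # pure vision-encoder weights saved with direct_save / loaded with direct_load
        image_encoder.load_state_dict(state_dict, strict=False)
```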
@@ -114,9 +114,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):

# UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
# UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "ResnetBlock2D"]
UNET_TARGET_REPLACE_MODULE = ["''UNet2DConditionModel''"]
UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"]
# UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["'UNet2DConditionModel'"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -155,6 +155,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
is_lorm: bool = False,
ignore_if_contains = None,
parameter_threshold: float = 0.0,
attn_only: bool = False,
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
**kwargs
@@ -243,6 +244,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
# for child_name, child_module in module.named_modules():
is_linear = module_name == 'LoRACompatibleLinear'
is_conv2d = module_name == 'LoRACompatibleConv'
# check if attn in name
is_attention = "attentions" in name
if not is_attention and attn_only:
continue

if is_linear and self.lora_dim is None:
continue

@@ -23,6 +23,8 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=Tru
# if not float, int, bool, or str, convert to json string
if not isinstance(value, str):
save_meta[key] = json.dumps(value)
# add the pt format
save_meta["format"] = "pt"
return save_meta

toolkit/models/ilora.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import weakref

import torch
import torch.nn as nn
from typing import TYPE_CHECKING
from toolkit.models.clip_fusion import ZipperBlock

if TYPE_CHECKING:
from toolkit.lora_special import LoRAModule
from toolkit.stable_diffusion_model import StableDiffusion


class InstantLoRAMidModule(torch.nn.Module):
def __init__(
self,
dim: int,
vision_tokens: int,
vision_hidden_size: int,
lora_module: 'LoRAModule',
instant_lora_module: 'InstantLoRAModule'
):
super(InstantLoRAMidModule, self).__init__()
self.dim = dim
self.vision_tokens = vision_tokens
self.vision_hidden_size = vision_hidden_size
self.lora_module_ref = weakref.ref(lora_module)
self.instant_lora_module_ref = weakref.ref(instant_lora_module)

self.zip = ZipperBlock(
in_size=self.vision_hidden_size,
in_tokens=self.vision_tokens,
out_size=self.dim,
out_tokens=1,
hidden_size=self.dim,
hidden_tokens=self.vision_tokens
)

def forward(self, x, *args, **kwargs):
# get the vector
img_embeds = self.instant_lora_module_ref().img_embeds
# project it
scaler = self.zip(img_embeds) # (batch_size, 1, dim)

# remove the channel dim
scaler = scaler.squeeze(1)

# double up if batch is 2x the size on x (cfg)
if x.shape[0] // 2 == scaler.shape[0]:
scaler = torch.cat([scaler, scaler], dim=0)

# multiply it by the scaler
try:
# reshape if needed
if len(x.shape) == 3:
scaler = scaler.unsqueeze(1)
x = x * scaler
except Exception as e:
print(e)
print(x.shape)
print(scaler.shape)
raise e
# apply tanh to limit values to -1 to 1
scaler = torch.tanh(scaler)
return x * scaler


class InstantLoRAModule(torch.nn.Module):
def __init__(
self,
vision_hidden_size: int,
vision_tokens: int,
sd: 'StableDiffusion'
):
super(InstantLoRAModule, self).__init__()
self.linear = torch.nn.Linear(2, 1)
self.sd_ref = weakref.ref(sd)
self.dim = sd.network.lora_dim
self.vision_hidden_size = vision_hidden_size
self.vision_tokens = vision_tokens

# stores the projection vector. Grabbed by modules
self.img_embeds: torch.Tensor = None

# disable merging in. It is slower on inference
self.sd_ref().network.can_merge_in = False

self.ilora_modules = torch.nn.ModuleList()

lora_modules = self.sd_ref().network.get_all_modules()

for lora_module in lora_modules:
# add a new mid module that will take the original forward and add a vector to it
# this will be used to add the vector to the original forward
mid_module = InstantLoRAMidModule(self.dim, self.vision_tokens, self.vision_hidden_size, lora_module, self)

self.ilora_modules.append(mid_module)
# replace the LoRA lora_mid
lora_module.lora_mid = mid_module.forward

# add a new mid module that will take the original forward and add a vector to it
# this will be used to add the vector to the original forward

def forward(self, x):
return self.linear(x)
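Taken together, the iLoRA ("instant LoRA") adapter works as a side channel: CustomAdapter stores the CLIP image hidden states on InstantLoRAModule.img_embeds, and every LoRA module's mid path multiplies its activations by a per-layer scale projected from those embeds, so a single reference image effectively "one-shots" the LoRA for that batch. A stripped-down, self-contained sketch of that wiring (ZipperBlock and the toolkit classes are replaced by a plain Linear; this is an assumption-laden toy, not the file above):

```python
import weakref
import torch
import torch.nn as nn

class ToyMidModule(nn.Module):
    """Projects shared image embeds to a per-layer scale and applies it to LoRA activations."""
    def __init__(self, dim: int, vision_hidden_size: int, owner: "ToyInstantLoRA"):
        super().__init__()
        self.proj = nn.Linear(vision_hidden_size, dim)   # stand-in for ZipperBlock
        self.owner_ref = weakref.ref(owner)              # avoid a reference cycle, as in the file above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        img_embeds = self.owner_ref().img_embeds          # (batch, tokens, hidden), set externally
        scaler = self.proj(img_embeds.mean(dim=1))        # pool tokens -> (batch, dim)
        scaler = torch.tanh(scaler).unsqueeze(1)          # limit to [-1, 1], broadcast over sequence
        return x * scaler

class ToyInstantLoRA(nn.Module):
    def __init__(self, lora_dims, vision_hidden_size: int):
        super().__init__()
        self.img_embeds = None                            # stored by the adapter before the UNet pass
        self.mid_modules = nn.ModuleList(
            ToyMidModule(dim, vision_hidden_size, self) for dim in lora_dims
        )

ilora = ToyInstantLoRA(lora_dims=[16, 16], vision_hidden_size=1024)
ilora.img_embeds = torch.rand(2, 257, 1024)               # CLIP hidden states for the batch
lora_activation = torch.rand(2, 77, 16)                   # some LoRA down-projection output
print(ilora.mid_modules[0](lora_activation).shape)        # torch.Size([2, 77, 16])
```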
@@ -215,12 +215,17 @@ def save_ip_adapter_from_diffusers(
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
direct_save: bool = False
):
# todo: test compatibility with non diffusers

converted_state_dict = OrderedDict()
for module_name, state_dict in combined_state_dict.items():
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
if direct_save:
converted_state_dict[module_name] = state_dict.detach().to('cpu', dtype=dtype)
else:
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)

# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
@@ -230,12 +235,15 @@ def save_ip_adapter_from_diffusers(
def load_ip_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
dtype: torch.dtype = torch.float32,
direct_load: bool = False
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
if direct_load:
return raw_state_dict
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)
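Normal IP adapter checkpoints are flattened to "module_name.key" entries on save and regrouped into nested dicts on load; direct_save / direct_load bypass that round trip when the checkpoint is just the raw image-encoder state dict. A small sketch of the flatten/regroup convention (illustrative helper names, not the toolkit's functions):

```python
from collections import OrderedDict

def flatten(combined: dict) -> OrderedDict:
    flat = OrderedDict()
    for module_name, module_sd in combined.items():
        for key, value in module_sd.items():
            flat[f"{module_name}.{key}"] = value          # e.g. "image_proj.weight"
    return flat

def regroup(flat: dict) -> OrderedDict:
    combined = OrderedDict()
    for combo_key, value in flat.items():
        module_name, key = combo_key.split(".", 1)        # first dot separates the module name
        combined.setdefault(module_name, OrderedDict())[key] = value
    return combined

combined = {"image_proj": {"weight": 1.0}, "ip_adapter": {"0.bias": 2.0}}
assert regroup(flatten(combined)) == combined
```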