Many bug fixes. IP adapter bug fixes. Added noise to the unconditional image embeds, which works better. Added an iLoRA adapter for one-shotting LoRAs.

Jaret Burkett
2024-01-28 08:20:03 -07:00
parent f17ad8d794
commit 92b9c71d44
10 changed files with 352 additions and 56 deletions


@@ -286,9 +286,10 @@ class SDTrainer(BaseSDTrainProcess):
)
prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
if torch.isnan(prior_loss).any():
raise ValueError("Prior loss is nan")
prior_loss = prior_loss.mean([1, 2, 3])
print("Prior loss is nan")
prior_loss = None
else:
prior_loss = prior_loss.mean([1, 2, 3])
# loss = loss + prior_loss
loss = loss.mean([1, 2, 3])
if prior_loss is not None:
@@ -992,6 +993,8 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'):
# number of images to do if doing a quad image
quad_count = random.randint(1, 4)
image_size = self.adapter.input_size
if is_reg:
# we will zero it out in the img embedder
@@ -1004,7 +1007,8 @@ class SDTrainer(BaseSDTrainProcess):
clip_images,
drop=True,
is_training=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
@@ -1014,13 +1018,15 @@ class SDTrainer(BaseSDTrainProcess):
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
elif has_clip_image:
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
clip_images.detach().to(self.device_torch, dtype=dtype),
is_training=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
if self.train_config.do_cfg:
unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
@@ -1030,7 +1036,8 @@ class SDTrainer(BaseSDTrainProcess):
).detach(),
is_training=True,
drop=True,
has_been_preprocessed=True
has_been_preprocessed=True,
quad_count=quad_count
)
else:
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
@@ -1152,7 +1159,8 @@ class SDTrainer(BaseSDTrainProcess):
)
# check if nan
if torch.isnan(loss):
raise ValueError("loss is nan")
print("loss is nan")
loss = torch.zeros_like(loss).requires_grad_(True)
with self.timer('backward'):
# todo we have multiplier seperated. works for now as res are not in same batch, but need to change
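
The NaN handling in this file is softened: instead of raising, the trainer now logs the event and either drops the prior loss for that batch or swaps the main loss for a zero tensor that still requires grad, so a single bad batch no longer kills a long run. A minimal standalone sketch of that pattern (not the trainer's exact code):

import torch

def guard_nan_loss(loss: torch.Tensor) -> torch.Tensor:
    # If a batch produces a NaN loss, log it and return a zero tensor that still
    # participates in autograd, so backward() runs and training continues.
    if torch.isnan(loss).any():
        print("loss is nan")
        return torch.zeros_like(loss).requires_grad_(True)
    return loss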


@@ -2,6 +2,7 @@ import copy
import glob
import inspect
import json
import random
import shutil
from collections import OrderedDict
import os
@@ -423,7 +424,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
adapter_name += '_t2i'
elif self.adapter_config.type == 'clip':
adapter_name += '_clip'
elif self.adapter_config.type == 'ip':
elif self.adapter_config.type.startswith('ip'):
adapter_name += '_ip'
else:
adapter_name += '_adapter'
@@ -444,7 +445,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
state_dict,
output_file=file_path,
meta=save_meta,
dtype=get_torch_dtype(self.save_config.dtype)
dtype=get_torch_dtype(self.save_config.dtype),
direct_save=self.adapter_config.train_only_image_encoder
)
else:
if self.save_config.save_format == "diffusers":
@@ -1010,7 +1012,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
loaded_state_dict = load_ip_adapter_model(
latest_save_path,
self.device,
dtype=dtype
dtype=dtype,
direct_load=self.adapter_config.train_only_image_encoder
)
self.adapter.load_state_dict(loaded_state_dict)
else:
@@ -1146,14 +1149,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)
# load the adapters before the dataset as they may use the clip encoders
if self.adapter_config is not None:
self.setup_adapter()
flush()
if not self.is_fine_tuning:
if self.network_config is not None:
# TODO should we completely switch to LycorisSpecialNetwork?
network_kwargs = {}
network_kwargs = self.network_config.network_kwargs
is_lycoris = False
is_lorm = self.network_config.type.lower() == 'lorm'
# default to LoCON if there are any conv layers or if it is named
@@ -1279,12 +1279,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
flush()
if self.adapter_config is not None:
# self.setup_adapter()
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
self.setup_adapter()
if self.adapter_config.train:
# set trainable params
params.append({
'params': self.adapter.parameters(),
'lr': self.train_config.adapter_lr
})
if self.train_config.gradient_checkpointing:
self.adapter.enable_gradient_checkpointing()
flush()
params = self.load_additional_training_modules(params)
@@ -1306,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
refiner_lr=self.train_config.refiner_lr,
)
# we may be using it for prompt injections
if self.adapter_config is not None:
if self.adapter_config is not None and self.adapter is None:
self.setup_adapter()
flush()
### HOOK ###
@@ -1379,7 +1383,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
# sample first
if self.train_config.skip_first_sample:
self.print("Skipping first sample due to config setting")
elif self.step_num <= 1:
elif self.step_num <= 1 or self.train_config.force_first_sample:
self.print("Generating baseline samples before training")
self.sample(self.step_num)
@@ -1422,6 +1426,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
start_step_num = self.step_num
did_first_flush = False
for step in range(start_step_num, self.train_config.steps):
if self.train_config.do_random_cfg:
self.train_config.do_cfg = True
self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
self.step_num = step
# default to true so various things can turn it off
self.is_grad_accumulation_step = True
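
With do_random_cfg enabled, every training step now switches CFG on and re-rolls cfg_scale between 1.0 and max_cfg_scale. A sketch of the sampling, assuming value_map is the toolkit's linear range-remap helper:

import random

def value_map(x, in_min, in_max, out_min, out_max):
    # assumed behavior of the helper: linearly remap x from [in_min, in_max] to [out_min, out_max]
    return out_min + (x - in_min) * (out_max - out_min) / (in_max - in_min)

max_cfg_scale = 7.0  # stand-in for train_config.max_cfg_scale
cfg_scale = value_map(random.random(), 0, 1, 1.0, max_cfg_scale)  # uniform in [1.0, max_cfg_scale)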


@@ -113,6 +113,7 @@ class NetworkConfig:
self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
self.dropout: Union[float, None] = kwargs.get('dropout', None)
self.network_kwargs: dict = kwargs.get('network_kwargs', {})
self.lorm_config: Union[LoRMConfig, None] = None
lorm = kwargs.get('lorm', None)
@@ -153,10 +154,14 @@ class AdapterConfig:
self.num_tokens: int = num_tokens
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
if self.train_only_image_encoder:
self.train_image_encoder = True
self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip') # clip vit vit_hybrid, safe
self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
self.safe_channels: int = kwargs.get('safe_channels', 2048)
self.safe_tokens: int = kwargs.get('safe_tokens', 8)
self.quad_image: bool = kwargs.get('quad_image', False)
# clip vision
self.trigger = kwargs.get('trigger', 'tri993r')
@@ -211,6 +216,7 @@ class TrainConfig:
self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
self.noise_offset = kwargs.get('noise_offset', 0.0)
self.skip_first_sample = kwargs.get('skip_first_sample', False)
self.force_first_sample = kwargs.get('force_first_sample', False)
self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
self.weight_jitter = kwargs.get('weight_jitter', 0.0)
self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
@@ -275,7 +281,9 @@ class TrainConfig:
self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
self.do_cfg = kwargs.get('do_cfg', False)
self.do_random_cfg = kwargs.get('do_random_cfg', False)
self.cfg_scale = kwargs.get('cfg_scale', 1.0)
self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)
class ModelConfig:
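
The config flags driving the behavior above are plain kwargs with defaults. A summary of the relevant keys as read from these hunks (shown as a Python dict; names and defaults come from the code, the comments are interpretation):

relevant_config_keys = {
    # NetworkConfig
    "network_kwargs": {},                # extra kwargs forwarded to the network constructor
    # AdapterConfig
    "train_only_image_encoder": False,   # when True, also forces train_image_encoder=True
    "quad_image": False,                 # treat the reference image as a 2x2 grid
    # TrainConfig
    "force_first_sample": False,         # generate baseline samples even when resuming past step 1
    "do_random_cfg": False,              # re-roll cfg_scale every step
    "cfg_scale": 1.0,
    "max_cfg_scale": 1.0,                # defaults to cfg_scale when not set
}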


@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.models.clip_fusion import CLIPFusionModule
from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
from toolkit.models.ilora import InstantLoRAModule
from toolkit.paths import REPOS_ROOT
from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
from toolkit.saving import load_ip_adapter_model
@@ -74,6 +75,7 @@ class CustomAdapter(torch.nn.Module):
self.clip_image_processor = self.image_processor
self.clip_fusion_module: CLIPFusionModule = None
self.ilora_module: InstantLoRAModule = None
self.setup_adapter()
@@ -106,6 +108,15 @@ class CustomAdapter(torch.nn.Module):
vision_hidden_size=self.vision_encoder.config.hidden_size,
vision_tokens=vision_tokens
)
elif self.adapter_type == 'ilora':
vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
if self.config.image_encoder_arch == 'clip':
vision_tokens = vision_tokens + 1
self.ilora_module = InstantLoRAModule(
vision_tokens=vision_tokens,
vision_hidden_size=self.vision_encoder.config.hidden_size,
sd=self.sd_ref()
)
else:
raise ValueError(f"unknown adapter type: {self.adapter_type}")
@@ -283,6 +294,9 @@ class CustomAdapter(torch.nn.Module):
if 'fuse_module' in state_dict:
self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)
if 'ilora' in state_dict:
self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
pass
def state_dict(self) -> OrderedDict:
@@ -301,6 +315,11 @@ class CustomAdapter(torch.nn.Module):
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
return state_dict
elif self.adapter_type == 'ilora':
if self.config.train_image_encoder:
state_dict["vision_encoder"] = self.vision_encoder.state_dict()
state_dict["ilora"] = self.ilora_module.state_dict()
return state_dict
else:
raise NotImplementedError
@@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module):
prompt: Union[List[str], str],
is_unconditional: bool = False,
):
if self.adapter_type == 'clip_fusion':
if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
return prompt
elif self.adapter_type == 'photo_maker':
if is_unconditional:
@@ -408,7 +427,7 @@ class CustomAdapter(torch.nn.Module):
has_been_preprocessed=False,
is_unconditional=False
) -> PromptEmbeds:
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
if is_unconditional:
# we dont condition the negative embeds for photo maker
return prompt_embeds.clone()
@@ -459,7 +478,7 @@ class CustomAdapter(torch.nn.Module):
self.token_mask
)
return prompt_embeds
elif self.adapter_type == 'clip_fusion':
elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
with torch.set_grad_enabled(is_training):
if is_training and self.config.train_image_encoder:
self.vision_encoder.train()
@@ -480,11 +499,17 @@ class CustomAdapter(torch.nn.Module):
if not is_training or not self.config.train_image_encoder:
img_embeds = img_embeds.detach()
prompt_embeds.text_embeds = self.clip_fusion_module(
prompt_embeds.text_embeds,
img_embeds
)
return prompt_embeds
if self.adapter_type == 'ilora':
self.ilora_module.img_embeds = img_embeds
return prompt_embeds
else:
prompt_embeds.text_embeds = self.clip_fusion_module(
prompt_embeds.text_embeds,
img_embeds
)
return prompt_embeds
else:
@@ -499,5 +524,9 @@ class CustomAdapter(torch.nn.Module):
yield from self.clip_fusion_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
elif self.config.type == 'ilora':
yield from self.ilora_module.parameters(recurse)
if self.config.train_image_encoder:
yield from self.vision_encoder.parameters(recurse)
else:
raise NotImplementedError
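
For the new 'ilora' branch, the token count handed to InstantLoRAModule comes straight from the vision encoder config. A worked example assuming a standard CLIP ViT-L/14 tower (image_size=224, patch_size=14 are that model's usual values, not read from this diff):

image_size, patch_size = 224, 14
vision_tokens = (image_size // patch_size) ** 2  # 16 * 16 = 256 patch tokens
vision_tokens += 1                               # +1 CLS token when image_encoder_arch == 'clip'
print(vision_tokens)                             # 257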


@@ -655,6 +655,21 @@ class ClipImageFileItemDTOMixin:
else:
self.clip_image_tensor = transforms.ToTensor()(img)
# random crop
# if self.dataset_config.clip_image_random_crop:
# # crop up to 20% on all sides. Keep is square
# crop_percent = random.randint(0, 20) / 100
# crop_width = int(self.clip_image_tensor.shape[2] * crop_percent)
# crop_height = int(self.clip_image_tensor.shape[1] * crop_percent)
# crop_left = random.randint(0, crop_width)
# crop_top = random.randint(0, crop_height)
# crop_right = self.clip_image_tensor.shape[2] - crop_width - crop_left
# crop_bottom = self.clip_image_tensor.shape[1] - crop_height - crop_top
# if len(self.clip_image_tensor.shape) == 3:
# self.clip_image_tensor = self.clip_image_tensor[:, crop_top:-crop_bottom, crop_left:-crop_right]
# elif len(self.clip_image_tensor.shape) == 4:
# self.clip_image_tensor = self.clip_image_tensor[:, :, crop_top:-crop_bottom, crop_left:-crop_right]
if self.clip_image_processor is not None:
# run it
tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16)
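
The commented-out block above sketches a square-preserving random crop for the clip image tensor; it stays disabled in this commit. A standalone version of the same idea, as a hedged sketch for a (C, H, W) tensor:

import random
import torch

def random_crop_keep_square(img: torch.Tensor, max_pct: float = 0.20) -> torch.Tensor:
    # Remove up to max_pct of the height and width, split randomly between opposite
    # sides, so a square input stays square.
    _, h, w = img.shape
    pct = random.uniform(0.0, max_pct)
    crop_h, crop_w = int(h * pct), int(w * pct)
    top = random.randint(0, crop_h)
    left = random.randint(0, crop_w)
    return img[:, top:h - (crop_h - top), left:w - (crop_w - left)]

cropped = random_crop_keep_square(torch.rand(3, 512, 512))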


@@ -1,3 +1,5 @@
import random
import torch
import sys
@@ -39,6 +41,9 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassificatio
from transformers import ViTFeatureExtractor, ViTForImageClassification
# gradient checkpointing
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
@@ -166,6 +171,8 @@ class IPAdapter(torch.nn.Module):
self.device = self.sd_ref().unet.device
self.preprocessor: Optional[CLIPImagePreProcessor] = None
self.input_size = 224
self.clip_noise_zero = True
self.unconditional: torch.Tensor = None
if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
try:
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -236,6 +243,16 @@ class IPAdapter(torch.nn.Module):
self.input_size = self.image_encoder.config.image_size
if self.config.quad_image: # 4x4 image
# self.clip_image_processor.config
# We do a 3x downscale of the image, so we need to adjust the input size
preprocessor_input_size = self.image_encoder.config.image_size * 2
# update the preprocessor so images come in at the right size
self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
self.clip_image_processor.crop_size['height'] = preprocessor_input_size
self.clip_image_processor.crop_size['width'] = preprocessor_input_size
if self.config.image_encoder_arch == 'clip+':
# self.clip_image_processor.config
# We do a 3x downscale of the image, so we need to adjust the input size
@@ -349,6 +366,15 @@ class IPAdapter(torch.nn.Module):
self.image_encoder.train()
self.image_encoder.requires_grad_(True)
# premake a unconditional
zerod = torch.zeros(1, 3, self.input_size, self.input_size, device=self.device, dtype=torch.float16)
self.unconditional = self.clip_image_processor(
images=zerod,
return_tensors="pt",
do_resize=True,
do_rescale=False,
).pixel_values
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.image_encoder.to(*args, **kwargs)
@@ -358,20 +384,23 @@ class IPAdapter(torch.nn.Module):
self.preprocessor.to(*args, **kwargs)
return self
def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"])
if self.preprocessor is not None and 'preprocessor' in state_dict:
self.preprocessor.load_state_dict(state_dict["preprocessor"])
# def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
# self.image_proj_model.load_state_dict(state_dict["image_proj"])
# ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
# ip_layers.load_state_dict(state_dict["ip_adapter"])
# if self.config.train_image_encoder and 'image_encoder' in state_dict:
# self.image_encoder.load_state_dict(state_dict["image_encoder"])
# if self.preprocessor is not None and 'preprocessor' in state_dict:
# self.preprocessor.load_state_dict(state_dict["preprocessor"])
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
# self.load_ip_adapter(state_dict)
def state_dict(self) -> OrderedDict:
state_dict = OrderedDict()
if self.config.train_only_image_encoder:
return self.image_encoder.state_dict()
state_dict["image_proj"] = self.image_proj_model.state_dict()
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
if self.config.train_image_encoder:
@@ -402,13 +431,28 @@ class IPAdapter(torch.nn.Module):
# clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
# return clip_image_embeds
def to(self, *args, **kwargs):
super().to(*args, **kwargs)
self.image_encoder.to(*args, **kwargs)
self.image_proj_model.to(*args, **kwargs)
self.adapter_modules.to(*args, **kwargs)
if self.preprocessor is not None:
self.preprocessor.to(*args, **kwargs)
return self
def get_clip_image_embeds_from_tensors(
self,
tensors_0_1: torch.Tensor,
drop=False,
is_training=False,
has_been_preprocessed=False
has_been_preprocessed=False,
quad_count=4,
) -> torch.Tensor:
if self.sd_ref().unet.device != self.device:
self.to(self.sd_ref().unet.device)
if self.sd_ref().unet.device != self.image_encoder.device:
self.to(self.sd_ref().unet.device)
if not self.config.train:
is_training = False
with torch.no_grad():
# on training the clip image is created in the dataloader
if not has_been_preprocessed:
@@ -417,11 +461,19 @@ class IPAdapter(torch.nn.Module):
tensors_0_1 = tensors_0_1.unsqueeze(0)
# training tensors are 0 - 1
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
# if images are out of this range throw error
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
tensors_0_1.min(), tensors_0_1.max()
))
# unconditional
if drop:
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
# tensors_0_1 = tensors_0_1 * 0
clip_image = self.clip_image_processor(
images=tensors_0_1,
return_tensors="pt",
@@ -429,10 +481,42 @@ class IPAdapter(torch.nn.Module):
do_rescale=False,
).pixel_values
else:
clip_image = tensors_0_1
if drop:
# scale the noise down
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
# tensors_0_1 = tensors_0_1 * 0
mean = torch.tensor(self.clip_image_processor.image_mean).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
std = torch.tensor(self.clip_image_processor.image_std).to(
self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
).detach()
tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
else:
clip_image = tensors_0_1
clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
if drop:
clip_image = clip_image * 0
if self.config.quad_image:
# split the 4x4 grid and stack on batch
ci1, ci2 = clip_image.chunk(2, dim=2)
ci1, ci3 = ci1.chunk(2, dim=3)
ci2, ci4 = ci2.chunk(2, dim=3)
to_cat = []
for i, ci in enumerate([ci1, ci2, ci3, ci4]):
if i < quad_count:
to_cat.append(ci)
else:
break
clip_image = torch.cat(to_cat, dim=0).detach()
# if drop:
# clip_image = clip_image * 0
with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
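
The "added noise to unconditional" part of the commit message lands here: when a clip image is dropped for the unconditional/CFG path, it is now replaced with uniform noise (clip_noise_zero) instead of zeros, and for tensors that were already preprocessed the noise is pushed through the CLIP mean/std normalization by hand. A minimal sketch of that drop path; the mean/std constants are the standard CLIP values, hard-coded here as an assumption (the adapter reads them from its image processor):

import torch

CLIP_MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
CLIP_STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

def drop_to_noise(tensors_0_1: torch.Tensor, clip_noise_zero: bool = True) -> torch.Tensor:
    # tensors_0_1: (B, 3, H, W) images in [0, 1], already resized for the encoder
    if clip_noise_zero:
        tensors_0_1 = torch.rand_like(tensors_0_1)   # noise "image" for the unconditional
    else:
        tensors_0_1 = torch.zeros_like(tensors_0_1)  # previous behavior: a black image
    # quantize to 8-bit levels, then apply the processor's normalization manually
    tensors_0_1 = torch.clip(255.0 * tensors_0_1, 0, 255).round() / 255.0
    return (tensors_0_1 - CLIP_MEAN.view(1, 3, 1, 1)) / CLIP_STD.view(1, 3, 1, 1)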
@@ -457,6 +541,20 @@ class IPAdapter(torch.nn.Module):
clip_image_embeds = clip_output.hidden_states[-2]
else:
clip_image_embeds = clip_output.image_embeds
if self.config.quad_image:
# get the outputs of the quat
chunks = clip_image_embeds.chunk(quad_count, dim=0)
chunk_sum = torch.zeros_like(chunks[0])
for chunk in chunks:
chunk_sum = chunk_sum + chunk
# get the mean of them
clip_image_embeds = chunk_sum / quad_count
if not is_training:
clip_image_embeds = clip_image_embeds.detach()
return clip_image_embeds
# use drop for prompt dropout, or negatives
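
With quad_image enabled, the reference image is a 2x2 grid: it is split into four tiles, the first quad_count of them (rolled 1-4 per step in SDTrainer) are stacked on the batch dimension, encoded, and the embeddings are averaged back to one per sample. A sketch of the split-and-average round trip:

import torch

def split_quad(grid: torch.Tensor, quad_count: int) -> torch.Tensor:
    # grid: (B, C, 2*H, 2*W) image holding a 2x2 grid of reference images
    top, bottom = grid.chunk(2, dim=2)
    tl, tr = top.chunk(2, dim=3)
    bl, br = bottom.chunk(2, dim=3)
    tiles = [tl, bl, tr, br][:quad_count]      # keep the first quad_count tiles
    return torch.cat(tiles, dim=0)             # (B * quad_count, C, H, W)

def merge_quad_embeds(embeds: torch.Tensor, quad_count: int) -> torch.Tensor:
    # embeds: (B * quad_count, tokens, dim) from the image encoder
    chunks = embeds.chunk(quad_count, dim=0)
    return sum(chunks) / quad_count            # mean over tiles -> (B, tokens, dim)

grid = torch.rand(1, 3, 448, 448)              # 2x2 grid of 224px tiles
tiles = split_quad(grid, quad_count=3)         # (3, 3, 224, 224)
embeds = torch.rand(tiles.shape[0], 257, 1024) # stand-in for encoder hidden states
pooled = merge_quad_embeds(embeds, quad_count=3)  # (1, 257, 1024)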
@@ -467,6 +565,9 @@ class IPAdapter(torch.nn.Module):
return embeddings
def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
if self.config.train_only_image_encoder:
yield from self.image_encoder.parameters(recurse)
return
for attn_processor in self.adapter_modules:
yield from attn_processor.parameters(recurse)
yield from self.image_proj_model.parameters(recurse)
@@ -561,17 +662,26 @@ class IPAdapter(torch.nn.Module):
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
try:
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
except Exception as e:
print(e)
print("could not load ip adapter weights, trying to merge in weights")
self.merge_in_weights(state_dict)
if 'ip_adapter' in state_dict:
try:
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
except Exception as e:
print(e)
print("could not load ip adapter weights, trying to merge in weights")
self.merge_in_weights(state_dict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
if self.preprocessor is not None and 'preprocessor' in state_dict:
self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
if self.config.train_only_image_encoder and 'ip_adapter' not in state_dict:
# we are loading pure clip weights.
self.image_encoder.load_state_dict(state_dict, strict=strict)
def enable_gradient_checkpointing(self):
self.image_encoder.gradient_checkpointing = True
if hasattr(self.image_encoder, "enable_gradient_checkpointing"):
self.image_encoder.enable_gradient_checkpointing()
elif hasattr(self.image_encoder, 'gradient_checkpointing'):
self.image_encoder.gradient_checkpointing = True


@@ -114,9 +114,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
# UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
# UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "ResnetBlock2D"]
UNET_TARGET_REPLACE_MODULE = ["''UNet2DConditionModel''"]
UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"]
# UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["'UNet2DConditionModel'"]
UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -155,6 +155,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
is_lorm: bool = False,
ignore_if_contains = None,
parameter_threshold: float = 0.0,
attn_only: bool = False,
target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
**kwargs
@@ -243,6 +244,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
# for child_name, child_module in module.named_modules():
is_linear = module_name == 'LoRACompatibleLinear'
is_conv2d = module_name == 'LoRACompatibleConv'
# check if attn in name
is_attention = "attentions" in name
if not is_attention and attn_only:
continue
if is_linear and self.lora_dim is None:
continue
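
The new attn_only flag limits LoRA injection to modules whose name contains "attentions". A short illustration of which names pass the filter; the module paths are standard diffusers UNet naming, used here as an assumption:

attn_only = True
candidate_names = [
    "down_blocks.1.attentions.0.transformer_blocks.0.attn1.to_q",  # kept: attention projection
    "down_blocks.1.resnets.0.conv1",                               # skipped when attn_only=True
    "mid_block.attentions.0.proj_out",                             # kept
]
kept = [n for n in candidate_names if ("attentions" in n) or not attn_only]
print(kept)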


@@ -23,6 +23,8 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=Tru
# if not float, int, bool, or str, convert to json string
if not isinstance(value, str):
save_meta[key] = json.dumps(value)
# add the pt format
save_meta["format"] = "pt"
return save_meta
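
Safetensors metadata is a flat string-to-string map, and many loaders expect a "format" key, so the hunk above stamps "pt" into every file saved through this helper. A small usage sketch (the extra key is hypothetical; values must already be strings or JSON-encoded):

from collections import OrderedDict

import torch
from safetensors.torch import save_file

meta = OrderedDict()
meta["format"] = "pt"          # mark the payload as PyTorch tensors
meta["ss_network_dim"] = "16"  # hypothetical extra key, already a string
save_file({"example.weight": torch.zeros(4, 4)}, "example.safetensors", metadata=meta)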

toolkit/models/ilora.py (new file, 104 lines)

@@ -0,0 +1,104 @@
import weakref
import torch
import torch.nn as nn
from typing import TYPE_CHECKING
from toolkit.models.clip_fusion import ZipperBlock
if TYPE_CHECKING:
from toolkit.lora_special import LoRAModule
from toolkit.stable_diffusion_model import StableDiffusion
class InstantLoRAMidModule(torch.nn.Module):
def __init__(
self,
dim: int,
vision_tokens: int,
vision_hidden_size: int,
lora_module: 'LoRAModule',
instant_lora_module: 'InstantLoRAModule'
):
super(InstantLoRAMidModule, self).__init__()
self.dim = dim
self.vision_tokens = vision_tokens
self.vision_hidden_size = vision_hidden_size
self.lora_module_ref = weakref.ref(lora_module)
self.instant_lora_module_ref = weakref.ref(instant_lora_module)
self.zip = ZipperBlock(
in_size=self.vision_hidden_size,
in_tokens=self.vision_tokens,
out_size=self.dim,
out_tokens=1,
hidden_size=self.dim,
hidden_tokens=self.vision_tokens
)
def forward(self, x, *args, **kwargs):
# get the vector
img_embeds = self.instant_lora_module_ref().img_embeds
# project it
scaler = self.zip(img_embeds) # (batch_size, 1, dim)
# remove the channel dim
scaler = scaler.squeeze(1)
# double up if batch is 2x the size on x (cfg)
if x.shape[0] // 2 == scaler.shape[0]:
scaler = torch.cat([scaler, scaler], dim=0)
# multiply it by the scaler
try:
# reshape if needed
if len(x.shape) == 3:
scaler = scaler.unsqueeze(1)
x = x * scaler
except Exception as e:
print(e)
print(x.shape)
print(scaler.shape)
raise e
# apply tanh to limit values to -1 to 1
scaler = torch.tanh(scaler)
return x * scaler
class InstantLoRAModule(torch.nn.Module):
def __init__(
self,
vision_hidden_size: int,
vision_tokens: int,
sd: 'StableDiffusion'
):
super(InstantLoRAModule, self).__init__()
self.linear = torch.nn.Linear(2, 1)
self.sd_ref = weakref.ref(sd)
self.dim = sd.network.lora_dim
self.vision_hidden_size = vision_hidden_size
self.vision_tokens = vision_tokens
# stores the projection vector. Grabbed by modules
self.img_embeds: torch.Tensor = None
# disable merging in. It is slower on inference
self.sd_ref().network.can_merge_in = False
self.ilora_modules = torch.nn.ModuleList()
lora_modules = self.sd_ref().network.get_all_modules()
for lora_module in lora_modules:
# add a new mid module that will take the original forward and add a vector to it
# this will be used to add the vector to the original forward
mid_module = InstantLoRAMidModule(self.dim, self.vision_tokens, self.vision_hidden_size, lora_module, self)
self.ilora_modules.append(mid_module)
# replace the LoRA lora_mid
lora_module.lora_mid = mid_module.forward
# add a new mid module that will take the original forward and add a vector to it
# this will be used to add the vector to the original forward
def forward(self, x):
return self.linear(x)
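
toolkit/models/ilora.py is the new instant-LoRA piece: the image embedding is projected, per LoRA layer, into a scaling vector that modulates that layer's low-rank activations, so one reference image steers the LoRA at run time instead of training a LoRA per subject. A self-contained toy version of the idea (not the repo's classes; the ZipperBlock projection is replaced by a plain linear layer):

import torch
import torch.nn as nn

class ToyILoRALayer(nn.Module):
    def __init__(self, features: int, rank: int, vision_dim: int):
        super().__init__()
        self.down = nn.Linear(features, rank, bias=False)  # LoRA down projection
        self.up = nn.Linear(rank, features, bias=False)    # LoRA up projection
        self.to_scale = nn.Linear(vision_dim, rank)        # image embedding -> per-rank scale

    def forward(self, x: torch.Tensor, img_embed: torch.Tensor) -> torch.Tensor:
        # x: (B, tokens, features) hidden states; img_embed: (B, vision_dim) pooled vision feature
        scale = torch.tanh(self.to_scale(img_embed)).unsqueeze(1)  # (B, 1, rank), bounded to [-1, 1]
        return self.up(self.down(x) * scale)                       # image-conditioned low-rank update

layer = ToyILoRALayer(features=320, rank=16, vision_dim=1024)
delta = layer(torch.randn(2, 77, 320), torch.randn(2, 1024))
print(delta.shape)  # torch.Size([2, 77, 320])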


@@ -215,12 +215,17 @@ def save_ip_adapter_from_diffusers(
output_file: str,
meta: 'OrderedDict',
dtype=get_torch_dtype('fp16'),
direct_save: bool = False
):
# todo: test compatibility with non diffusers
converted_state_dict = OrderedDict()
for module_name, state_dict in combined_state_dict.items():
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
if direct_save:
converted_state_dict[module_name] = state_dict.detach().to('cpu', dtype=dtype)
else:
for key, value in state_dict.items():
converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
# make sure parent folder exists
os.makedirs(os.path.dirname(output_file), exist_ok=True)
@@ -230,12 +235,15 @@ def save_ip_adapter_from_diffusers(
def load_ip_adapter_model(
path_to_file,
device: Union[str] = 'cpu',
dtype: torch.dtype = torch.float32
dtype: torch.dtype = torch.float32,
direct_load: bool = False
):
# check if it is safetensors or checkpoint
if path_to_file.endswith('.safetensors'):
raw_state_dict = load_file(path_to_file, device)
combined_state_dict = OrderedDict()
if direct_load:
return raw_state_dict
for combo_key, value in raw_state_dict.items():
key_split = combo_key.split('.')
module_name = key_split.pop(0)
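
With train_only_image_encoder set, the checkpoint on disk is just the vision encoder's state dict, so saving and loading skip the module-name prefixing (direct_save / direct_load). A hedged sketch of the two load paths for a .safetensors file (helper name is illustrative, not the repo's exact function):

from collections import OrderedDict

from safetensors.torch import load_file

def load_adapter_weights(path: str, device: str = "cpu", direct_load: bool = False):
    raw = load_file(path, device=device)
    if direct_load:
        # image-encoder-only checkpoint: keys are plain vision-encoder keys, return as-is
        return raw
    # full IP adapter checkpoint: fold "module_name.key" back into nested dicts
    combined = OrderedDict()
    for combo_key, value in raw.items():
        module_name, key = combo_key.split(".", 1)
        combined.setdefault(module_name, OrderedDict())[key] = value
    return combined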