Many bug fixes. IP adapter bug fixes. Added noise to the unconditional (it works better). Added an ilora adapter for one-shotting LoRAs.

Jaret Burkett
2024-01-28 08:20:03 -07:00
parent f17ad8d794
commit 92b9c71d44
10 changed files with 352 additions and 56 deletions
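
The new options referenced above map onto config keys added in the diffs below. A hypothetical sketch of how they might be set; the key names come from the config changes, while the dict layout and values are assumptions:

# Hypothetical sketch only: key names come from the kwargs.get(...) calls added
# in the config diff below; the dict layout and values are assumptions, not a documented API.
train_kwargs = dict(
    do_random_cfg=True,        # pick a random CFG scale each training step
    max_cfg_scale=3.0,         # upper bound for the random scale (defaults to cfg_scale)
    force_first_sample=True,   # generate baseline samples even when resuming
)
adapter_kwargs = dict(
    type='ilora',                    # new InstantLoRA adapter path in CustomAdapter
    quad_image=True,                 # encode a 2x2 grid of images and average the embeds
    train_only_image_encoder=False,  # True saves/loads raw image-encoder weights directly
)
network_kwargs = dict(
    network_kwargs={'attn_only': True},  # forwarded to LoRASpecialNetwork (new attn_only flag)
)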


@@ -286,9 +286,10 @@ class SDTrainer(BaseSDTrainProcess):
                 )
                 prior_loss = prior_loss * prior_mask_multiplier * self.train_config.inverted_mask_prior_multiplier
                 if torch.isnan(prior_loss).any():
-                    raise ValueError("Prior loss is nan")
-                prior_loss = prior_loss.mean([1, 2, 3])
+                    print("Prior loss is nan")
+                    prior_loss = None
+                else:
+                    prior_loss = prior_loss.mean([1, 2, 3])
             # loss = loss + prior_loss
             loss = loss.mean([1, 2, 3])
             if prior_loss is not None:
@@ -992,6 +993,8 @@ class SDTrainer(BaseSDTrainProcess):
         if self.adapter and isinstance(self.adapter, IPAdapter):
             with self.timer('encode_adapter_embeds'):
+                # number of images to do if doing a quad image
+                quad_count = random.randint(1, 4)
                 image_size = self.adapter.input_size
                 if is_reg:
                     # we will zero it out in the img embedder
@@ -1004,7 +1007,8 @@ class SDTrainer(BaseSDTrainProcess):
                         clip_images,
                         drop=True,
                         is_training=True,
-                        has_been_preprocessed=True
+                        has_been_preprocessed=True,
+                        quad_count=quad_count
                     )
                     if self.train_config.do_cfg:
                         unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
@@ -1014,13 +1018,15 @@ class SDTrainer(BaseSDTrainProcess):
                             ).detach(),
                             is_training=True,
                             drop=True,
-                            has_been_preprocessed=True
+                            has_been_preprocessed=True,
+                            quad_count=quad_count
                         )
                 elif has_clip_image:
                     conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
                         clip_images.detach().to(self.device_torch, dtype=dtype),
                         is_training=True,
-                        has_been_preprocessed=True
+                        has_been_preprocessed=True,
+                        quad_count=quad_count
                     )
                     if self.train_config.do_cfg:
                         unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
@@ -1030,7 +1036,8 @@ class SDTrainer(BaseSDTrainProcess):
                             ).detach(),
                             is_training=True,
                             drop=True,
-                            has_been_preprocessed=True
+                            has_been_preprocessed=True,
+                            quad_count=quad_count
                         )
                 else:
                     raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
@@ -1152,7 +1159,8 @@ class SDTrainer(BaseSDTrainProcess):
                 )
                 # check if nan
                 if torch.isnan(loss):
-                    raise ValueError("loss is nan")
+                    print("loss is nan")
+                    loss = torch.zeros_like(loss).requires_grad_(True)
         with self.timer('backward'):
             # todo we have multiplier seperated. works for now as res are not in same batch, but need to change


@@ -2,6 +2,7 @@ import copy
 import glob
 import inspect
 import json
+import random
 import shutil
 from collections import OrderedDict
 import os
@@ -423,7 +424,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 adapter_name += '_t2i'
             elif self.adapter_config.type == 'clip':
                 adapter_name += '_clip'
-            elif self.adapter_config.type == 'ip':
+            elif self.adapter_config.type.startswith('ip'):
                 adapter_name += '_ip'
             else:
                 adapter_name += '_adapter'
@@ -444,7 +445,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 state_dict,
                 output_file=file_path,
                 meta=save_meta,
-                dtype=get_torch_dtype(self.save_config.dtype)
+                dtype=get_torch_dtype(self.save_config.dtype),
+                direct_save=self.adapter_config.train_only_image_encoder
             )
         else:
             if self.save_config.save_format == "diffusers":
@@ -1010,7 +1012,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 loaded_state_dict = load_ip_adapter_model(
                     latest_save_path,
                     self.device,
-                    dtype=dtype
+                    dtype=dtype,
+                    direct_load=self.adapter_config.train_only_image_encoder
                 )
                 self.adapter.load_state_dict(loaded_state_dict)
             else:
@@ -1146,14 +1149,11 @@ class BaseSDTrainProcess(BaseTrainProcess):
             self.snr_gos.scale.data = torch.tensor(json_data['scale'], device=self.device_torch)
             self.snr_gos.gamma.data = torch.tensor(json_data['gamma'], device=self.device_torch)

-        # load the adapters before the dataset as they may use the clip encoders
-        if self.adapter_config is not None:
-            self.setup_adapter()
         flush()

         if not self.is_fine_tuning:
             if self.network_config is not None:
                 # TODO should we completely switch to LycorisSpecialNetwork?
-                network_kwargs = {}
+                network_kwargs = self.network_config.network_kwargs
                 is_lycoris = False
                 is_lorm = self.network_config.type.lower() == 'lorm'
                 # default to LoCON if there are any conv layers or if it is named
@@ -1279,12 +1279,16 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 flush()

             if self.adapter_config is not None:
-                # self.setup_adapter()
-                # set trainable params
-                params.append({
-                    'params': self.adapter.parameters(),
-                    'lr': self.train_config.adapter_lr
-                })
+                self.setup_adapter()
+                if self.adapter_config.train:
+                    # set trainable params
+                    params.append({
+                        'params': self.adapter.parameters(),
+                        'lr': self.train_config.adapter_lr
+                    })
+                    if self.train_config.gradient_checkpointing:
+                        self.adapter.enable_gradient_checkpointing()
                 flush()

             params = self.load_additional_training_modules(params)
@@ -1306,7 +1310,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
                 refiner_lr=self.train_config.refiner_lr,
             )
             # we may be using it for prompt injections
-            if self.adapter_config is not None:
+            if self.adapter_config is not None and self.adapter is None:
                 self.setup_adapter()
             flush()

         ### HOOK ###
@@ -1379,7 +1383,7 @@ class BaseSDTrainProcess(BaseTrainProcess):
         # sample first
         if self.train_config.skip_first_sample:
             self.print("Skipping first sample due to config setting")
-        elif self.step_num <= 1:
+        elif self.step_num <= 1 or self.train_config.force_first_sample:
             self.print("Generating baseline samples before training")
             self.sample(self.step_num)
@@ -1422,6 +1426,9 @@ class BaseSDTrainProcess(BaseTrainProcess):
         start_step_num = self.step_num
         did_first_flush = False
         for step in range(start_step_num, self.train_config.steps):
+            if self.train_config.do_random_cfg:
+                self.train_config.do_cfg = True
+                self.train_config.cfg_scale = value_map(random.random(), 0, 1, 1.0, self.train_config.max_cfg_scale)
             self.step_num = step
             # default to true so various things can turn it off
             self.is_grad_accumulation_step = True
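
For reference, a minimal sketch of the remapping value_map is assumed to perform in the new do_random_cfg branch above; value_map itself is a toolkit helper and is not part of this diff:

# Assumed behavior: linear remap of r in [0, 1] onto [1.0, max_cfg_scale].
def value_map(value, in_min, in_max, out_min, out_max):
    return out_min + (value - in_min) * (out_max - out_min) / (in_max - in_min)

assert value_map(0.0, 0, 1, 1.0, 3.0) == 1.0
assert value_map(0.5, 0, 1, 1.0, 3.0) == 2.0
assert value_map(1.0, 0, 1, 1.0, 3.0) == 3.0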


@@ -113,6 +113,7 @@ class NetworkConfig:
         self.linear_alpha: float = kwargs.get('linear_alpha', self.alpha)
         self.conv_alpha: float = kwargs.get('conv_alpha', self.conv)
         self.dropout: Union[float, None] = kwargs.get('dropout', None)
+        self.network_kwargs: dict = kwargs.get('network_kwargs', {})

         self.lorm_config: Union[LoRMConfig, None] = None
         lorm = kwargs.get('lorm', None)
@@ -153,10 +154,14 @@ class AdapterConfig:
         self.num_tokens: int = num_tokens
         self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
+        self.train_only_image_encoder: bool = kwargs.get('train_only_image_encoder', False)
+        if self.train_only_image_encoder:
+            self.train_image_encoder = True
         self.image_encoder_arch: str = kwargs.get('image_encoder_arch', 'clip')  # clip vit vit_hybrid, safe
         self.safe_reducer_channels: int = kwargs.get('safe_reducer_channels', 512)
         self.safe_channels: int = kwargs.get('safe_channels', 2048)
         self.safe_tokens: int = kwargs.get('safe_tokens', 8)
+        self.quad_image: bool = kwargs.get('quad_image', False)
         # clip vision
         self.trigger = kwargs.get('trigger', 'tri993r')
@@ -211,6 +216,7 @@ class TrainConfig:
         self.learnable_snr_gos = kwargs.get('learnable_snr_gos', False)
         self.noise_offset = kwargs.get('noise_offset', 0.0)
         self.skip_first_sample = kwargs.get('skip_first_sample', False)
+        self.force_first_sample = kwargs.get('force_first_sample', False)
         self.gradient_checkpointing = kwargs.get('gradient_checkpointing', True)
         self.weight_jitter = kwargs.get('weight_jitter', 0.0)
         self.merge_network_on_save = kwargs.get('merge_network_on_save', False)
@@ -275,7 +281,9 @@ class TrainConfig:
         self.dynamic_noise_offset = kwargs.get('dynamic_noise_offset', False)
         self.do_cfg = kwargs.get('do_cfg', False)
+        self.do_random_cfg = kwargs.get('do_random_cfg', False)
         self.cfg_scale = kwargs.get('cfg_scale', 1.0)
+        self.max_cfg_scale = kwargs.get('max_cfg_scale', self.cfg_scale)


 class ModelConfig:


@@ -7,6 +7,7 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 from toolkit.models.clip_fusion import CLIPFusionModule
 from toolkit.models.clip_pre_processor import CLIPImagePreProcessor
+from toolkit.models.ilora import InstantLoRAModule
 from toolkit.paths import REPOS_ROOT
 from toolkit.photomaker import PhotoMakerIDEncoder, FuseModule, PhotoMakerCLIPEncoder
 from toolkit.saving import load_ip_adapter_model
@@ -74,6 +75,7 @@ class CustomAdapter(torch.nn.Module):
         self.clip_image_processor = self.image_processor
         self.clip_fusion_module: CLIPFusionModule = None
+        self.ilora_module: InstantLoRAModule = None

         self.setup_adapter()
@@ -106,6 +108,15 @@ class CustomAdapter(torch.nn.Module):
                 vision_hidden_size=self.vision_encoder.config.hidden_size,
                 vision_tokens=vision_tokens
             )
+        elif self.adapter_type == 'ilora':
+            vision_tokens = ((self.vision_encoder.config.image_size // self.vision_encoder.config.patch_size) ** 2)
+            if self.config.image_encoder_arch == 'clip':
+                vision_tokens = vision_tokens + 1
+            self.ilora_module = InstantLoRAModule(
+                vision_tokens=vision_tokens,
+                vision_hidden_size=self.vision_encoder.config.hidden_size,
+                sd=self.sd_ref()
+            )
         else:
             raise ValueError(f"unknown adapter type: {self.adapter_type}")
@@ -283,6 +294,9 @@ class CustomAdapter(torch.nn.Module):
         if 'fuse_module' in state_dict:
             self.fuse_module.load_state_dict(state_dict['fuse_module'], strict=strict)

+        if 'ilora' in state_dict:
+            self.ilora_module.load_state_dict(state_dict['ilora'], strict=strict)
+
         pass

     def state_dict(self) -> OrderedDict:
@@ -301,6 +315,11 @@ class CustomAdapter(torch.nn.Module):
                 state_dict["vision_encoder"] = self.vision_encoder.state_dict()
             state_dict["clip_fusion"] = self.clip_fusion_module.state_dict()
             return state_dict
+        elif self.adapter_type == 'ilora':
+            if self.config.train_image_encoder:
+                state_dict["vision_encoder"] = self.vision_encoder.state_dict()
+            state_dict["ilora"] = self.ilora_module.state_dict()
+            return state_dict
         else:
             raise NotImplementedError
@@ -309,7 +328,7 @@ class CustomAdapter(torch.nn.Module):
             prompt: Union[List[str], str],
             is_unconditional: bool = False,
     ):
-        if self.adapter_type == 'clip_fusion':
+        if self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
             return prompt
         elif self.adapter_type == 'photo_maker':
             if is_unconditional:
@@ -408,7 +427,7 @@ class CustomAdapter(torch.nn.Module):
             has_been_preprocessed=False,
             is_unconditional=False
     ) -> PromptEmbeds:
-        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion':
+        if self.adapter_type == 'photo_maker' or self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
             if is_unconditional:
                 # we dont condition the negative embeds for photo maker
                 return prompt_embeds.clone()
@@ -459,7 +478,7 @@ class CustomAdapter(torch.nn.Module):
                 self.token_mask
             )
             return prompt_embeds
-        elif self.adapter_type == 'clip_fusion':
+        elif self.adapter_type == 'clip_fusion' or self.adapter_type == 'ilora':
             with torch.set_grad_enabled(is_training):
                 if is_training and self.config.train_image_encoder:
                     self.vision_encoder.train()
@@ -480,11 +499,17 @@ class CustomAdapter(torch.nn.Module):
                 if not is_training or not self.config.train_image_encoder:
                     img_embeds = img_embeds.detach()

-            prompt_embeds.text_embeds = self.clip_fusion_module(
-                prompt_embeds.text_embeds,
-                img_embeds
-            )
-            return prompt_embeds
+            if self.adapter_type == 'ilora':
+                self.ilora_module.img_embeds = img_embeds
+                return prompt_embeds
+            else:
+                prompt_embeds.text_embeds = self.clip_fusion_module(
+                    prompt_embeds.text_embeds,
+                    img_embeds
+                )
+                return prompt_embeds

         else:
@@ -499,5 +524,9 @@ class CustomAdapter(torch.nn.Module):
             yield from self.clip_fusion_module.parameters(recurse)
             if self.config.train_image_encoder:
                 yield from self.vision_encoder.parameters(recurse)
+        elif self.config.type == 'ilora':
+            yield from self.ilora_module.parameters(recurse)
+            if self.config.train_image_encoder:
+                yield from self.vision_encoder.parameters(recurse)
         else:
             raise NotImplementedError


@@ -655,6 +655,21 @@ class ClipImageFileItemDTOMixin:
         else:
             self.clip_image_tensor = transforms.ToTensor()(img)

+        # random crop
+        # if self.dataset_config.clip_image_random_crop:
+        #     # crop up to 20% on all sides. Keep is square
+        #     crop_percent = random.randint(0, 20) / 100
+        #     crop_width = int(self.clip_image_tensor.shape[2] * crop_percent)
+        #     crop_height = int(self.clip_image_tensor.shape[1] * crop_percent)
+        #     crop_left = random.randint(0, crop_width)
+        #     crop_top = random.randint(0, crop_height)
+        #     crop_right = self.clip_image_tensor.shape[2] - crop_width - crop_left
+        #     crop_bottom = self.clip_image_tensor.shape[1] - crop_height - crop_top
+        #     if len(self.clip_image_tensor.shape) == 3:
+        #         self.clip_image_tensor = self.clip_image_tensor[:, crop_top:-crop_bottom, crop_left:-crop_right]
+        #     elif len(self.clip_image_tensor.shape) == 4:
+        #         self.clip_image_tensor = self.clip_image_tensor[:, :, crop_top:-crop_bottom, crop_left:-crop_right]
+
         if self.clip_image_processor is not None:
             # run it
             tensors_0_1 = self.clip_image_tensor.to(dtype=torch.float16)


@@ -1,3 +1,5 @@
+import random
+
 import torch
 import sys
@@ -39,6 +41,9 @@ from transformers import ViTHybridImageProcessor, ViTHybridForImageClassification
 from transformers import ViTFeatureExtractor, ViTForImageClassification

+# gradient checkpointing
+from torch.utils.checkpoint import checkpoint
+
 import torch.nn.functional as F
@@ -166,6 +171,8 @@ class IPAdapter(torch.nn.Module):
         self.device = self.sd_ref().unet.device
         self.preprocessor: Optional[CLIPImagePreProcessor] = None
         self.input_size = 224
+        self.clip_noise_zero = True
+        self.unconditional: torch.Tensor = None
         if self.config.image_encoder_arch == 'clip' or self.config.image_encoder_arch == 'clip+':
             try:
                 self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
@@ -236,6 +243,16 @@ class IPAdapter(torch.nn.Module):
             self.input_size = self.image_encoder.config.image_size

+        if self.config.quad_image:  # 4x4 image
+            # self.clip_image_processor.config
+            # We do a 3x downscale of the image, so we need to adjust the input size
+            preprocessor_input_size = self.image_encoder.config.image_size * 2
+
+            # update the preprocessor so images come in at the right size
+            self.clip_image_processor.size['shortest_edge'] = preprocessor_input_size
+            self.clip_image_processor.crop_size['height'] = preprocessor_input_size
+            self.clip_image_processor.crop_size['width'] = preprocessor_input_size
+
         if self.config.image_encoder_arch == 'clip+':
             # self.clip_image_processor.config
             # We do a 3x downscale of the image, so we need to adjust the input size
@@ -349,6 +366,15 @@ class IPAdapter(torch.nn.Module):
             self.image_encoder.train()
             self.image_encoder.requires_grad_(True)

+        # premake a unconditional
+        zerod = torch.zeros(1, 3, self.input_size, self.input_size, device=self.device, dtype=torch.float16)
+        self.unconditional = self.clip_image_processor(
+            images=zerod,
+            return_tensors="pt",
+            do_resize=True,
+            do_rescale=False,
+        ).pixel_values
+
     def to(self, *args, **kwargs):
         super().to(*args, **kwargs)
         self.image_encoder.to(*args, **kwargs)
@@ -358,20 +384,23 @@ class IPAdapter(torch.nn.Module):
             self.preprocessor.to(*args, **kwargs)
         return self

-    def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
-        self.image_proj_model.load_state_dict(state_dict["image_proj"])
-        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
-        ip_layers.load_state_dict(state_dict["ip_adapter"])
-        if self.config.train_image_encoder and 'image_encoder' in state_dict:
-            self.image_encoder.load_state_dict(state_dict["image_encoder"])
-        if self.preprocessor is not None and 'preprocessor' in state_dict:
-            self.preprocessor.load_state_dict(state_dict["preprocessor"])
+    # def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
+    #     self.image_proj_model.load_state_dict(state_dict["image_proj"])
+    #     ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+    #     ip_layers.load_state_dict(state_dict["ip_adapter"])
+    #     if self.config.train_image_encoder and 'image_encoder' in state_dict:
+    #         self.image_encoder.load_state_dict(state_dict["image_encoder"])
+    #     if self.preprocessor is not None and 'preprocessor' in state_dict:
+    #         self.preprocessor.load_state_dict(state_dict["preprocessor"])

     # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
     #     self.load_ip_adapter(state_dict)

     def state_dict(self) -> OrderedDict:
         state_dict = OrderedDict()
+        if self.config.train_only_image_encoder:
+            return self.image_encoder.state_dict()
+
         state_dict["image_proj"] = self.image_proj_model.state_dict()
         state_dict["ip_adapter"] = self.adapter_modules.state_dict()
         if self.config.train_image_encoder:
@@ -402,13 +431,28 @@ class IPAdapter(torch.nn.Module):
     #     clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
     #     return clip_image_embeds

+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.image_encoder.to(*args, **kwargs)
+        self.image_proj_model.to(*args, **kwargs)
+        self.adapter_modules.to(*args, **kwargs)
+        if self.preprocessor is not None:
+            self.preprocessor.to(*args, **kwargs)
+        return self
+
     def get_clip_image_embeds_from_tensors(
             self,
             tensors_0_1: torch.Tensor,
             drop=False,
             is_training=False,
-            has_been_preprocessed=False
+            has_been_preprocessed=False,
+            quad_count=4,
     ) -> torch.Tensor:
+        if self.sd_ref().unet.device != self.device:
+            self.to(self.sd_ref().unet.device)
+        if self.sd_ref().unet.device != self.image_encoder.device:
+            self.to(self.sd_ref().unet.device)
+        if not self.config.train:
+            is_training = False
         with torch.no_grad():
             # on training the clip image is created in the dataloader
             if not has_been_preprocessed:
@@ -417,11 +461,19 @@ class IPAdapter(torch.nn.Module):
                     tensors_0_1 = tensors_0_1.unsqueeze(0)
                 # training tensors are 0 - 1
                 tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
+
                 # if images are out of this range throw error
                 if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
                     raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
                         tensors_0_1.min(), tensors_0_1.max()
                     ))
+                # unconditional
+                if drop:
+                    if self.clip_noise_zero:
+                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
+                    else:
+                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
+                    # tensors_0_1 = tensors_0_1 * 0
                 clip_image = self.clip_image_processor(
                     images=tensors_0_1,
                     return_tensors="pt",
@@ -429,10 +481,42 @@ class IPAdapter(torch.nn.Module):
                     do_rescale=False,
                 ).pixel_values
             else:
-                clip_image = tensors_0_1
+                if drop:
+                    # scale the noise down
+                    if self.clip_noise_zero:
+                        tensors_0_1 = torch.rand_like(tensors_0_1).detach()
+                    else:
+                        tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
+                    # tensors_0_1 = tensors_0_1 * 0
+                    mean = torch.tensor(self.clip_image_processor.image_mean).to(
+                        self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+                    ).detach()
+                    std = torch.tensor(self.clip_image_processor.image_std).to(
+                        self.device, dtype=get_torch_dtype(self.sd_ref().dtype)
+                    ).detach()
+                    tensors_0_1 = torch.clip((255. * tensors_0_1), 0, 255).round() / 255.0
+                    clip_image = (tensors_0_1 - mean.view([1, 3, 1, 1])) / std.view([1, 3, 1, 1])
+                else:
+                    clip_image = tensors_0_1
             clip_image = clip_image.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)).detach()
-            if drop:
-                clip_image = clip_image * 0
+
+            if self.config.quad_image:
+                # split the 4x4 grid and stack on batch
+                ci1, ci2 = clip_image.chunk(2, dim=2)
+                ci1, ci3 = ci1.chunk(2, dim=3)
+                ci2, ci4 = ci2.chunk(2, dim=3)
+                to_cat = []
+                for i, ci in enumerate([ci1, ci2, ci3, ci4]):
+                    if i < quad_count:
+                        to_cat.append(ci)
+                    else:
+                        break
+
+                clip_image = torch.cat(to_cat, dim=0).detach()
+
+            # if drop:
+            #     clip_image = clip_image * 0
         with torch.set_grad_enabled(is_training):
             if is_training:
                 self.image_encoder.train()
@@ -457,6 +541,20 @@ class IPAdapter(torch.nn.Module):
                 clip_image_embeds = clip_output.hidden_states[-2]
             else:
                 clip_image_embeds = clip_output.image_embeds
+
+            if self.config.quad_image:
+                # get the outputs of the quat
+                chunks = clip_image_embeds.chunk(quad_count, dim=0)
+                chunk_sum = torch.zeros_like(chunks[0])
+                for chunk in chunks:
+                    chunk_sum = chunk_sum + chunk
+                # get the mean of them
+                clip_image_embeds = chunk_sum / quad_count
+
+            if not is_training:
+                clip_image_embeds = clip_image_embeds.detach()
+
         return clip_image_embeds

     # use drop for prompt dropout, or negatives
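
For illustration, a standalone sketch of the quad-image path: a grid at twice the encoder resolution is split into four crops stacked on the batch dimension, and their embeddings are averaged back into one per grid. Shapes and the embedding size below are assumptions, not values from the toolkit:

# Toy illustration of the quad-image split and embedding average (assumed shapes).
import torch

encoder_size, quad_count = 224, 3                      # use only 3 of the 4 cells
grid = torch.rand(1, 3, encoder_size * 2, encoder_size * 2)

ci1, ci2 = grid.chunk(2, dim=2)                        # split height: top / bottom
ci1, ci3 = ci1.chunk(2, dim=3)                         # split width of top half
ci2, ci4 = ci2.chunk(2, dim=3)                         # split width of bottom half
crops = torch.cat([ci1, ci2, ci3, ci4][:quad_count], dim=0)
print(crops.shape)                                     # torch.Size([3, 3, 224, 224])

fake_embeds = torch.rand(quad_count, 257, 1280)        # stand-in for encoder hidden states
avg = sum(fake_embeds.chunk(quad_count, dim=0)) / quad_count
print(avg.shape)                                       # torch.Size([1, 257, 1280])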
@@ -467,6 +565,9 @@ class IPAdapter(torch.nn.Module):
         return embeddings

     def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
+        if self.config.train_only_image_encoder:
+            yield from self.image_encoder.parameters(recurse)
+            return
         for attn_processor in self.adapter_modules:
             yield from attn_processor.parameters(recurse)
         yield from self.image_proj_model.parameters(recurse)
@@ -561,17 +662,26 @@ class IPAdapter(torch.nn.Module):
     def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
         strict = False
-        try:
-            self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
-            self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
-        except Exception as e:
-            print(e)
-            print("could not load ip adapter weights, trying to merge in weights")
-            self.merge_in_weights(state_dict)
+        if 'ip_adapter' in state_dict:
+            try:
+                self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
+                self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
+            except Exception as e:
+                print(e)
+                print("could not load ip adapter weights, trying to merge in weights")
+                self.merge_in_weights(state_dict)
         if self.config.train_image_encoder and 'image_encoder' in state_dict:
             self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
         if self.preprocessor is not None and 'preprocessor' in state_dict:
             self.preprocessor.load_state_dict(state_dict["preprocessor"], strict=strict)
+        if self.config.train_only_image_encoder and 'ip_adapter' not in state_dict:
+            # we are loading pure clip weights.
+            self.image_encoder.load_state_dict(state_dict, strict=strict)

     def enable_gradient_checkpointing(self):
-        self.image_encoder.gradient_checkpointing = True
+        if hasattr(self.image_encoder, "enable_gradient_checkpointing"):
+            self.image_encoder.enable_gradient_checkpointing()
+        elif hasattr(self.image_encoder, 'gradient_checkpointing'):
+            self.image_encoder.gradient_checkpointing = True


@@ -114,9 +114,9 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
     # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel"]
     # UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "ResnetBlock2D"]
-    UNET_TARGET_REPLACE_MODULE = ["'UNet2DConditionModel'"]
+    UNET_TARGET_REPLACE_MODULE = ["UNet2DConditionModel"]
     # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"]
-    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["'UNet2DConditionModel'"]
+    UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["UNet2DConditionModel"]
     TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"]
     LORA_PREFIX_UNET = "lora_unet"
     LORA_PREFIX_TEXT_ENCODER = "lora_te"
@@ -155,6 +155,7 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
             is_lorm: bool = False,
             ignore_if_contains = None,
             parameter_threshold: float = 0.0,
+            attn_only: bool = False,
             target_lin_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE,
             target_conv_modules=LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3,
             **kwargs
@@ -243,6 +244,10 @@ class LoRASpecialNetwork(ToolkitNetworkMixin, LoRANetwork):
                 # for child_name, child_module in module.named_modules():
                 is_linear = module_name == 'LoRACompatibleLinear'
                 is_conv2d = module_name == 'LoRACompatibleConv'
+                # check if attn in name
+                is_attention = "attentions" in name
+                if not is_attention and attn_only:
+                    continue
                 if is_linear and self.lora_dim is None:
                     continue


@@ -23,6 +23,8 @@ def get_meta_for_safetensors(meta: OrderedDict, name=None, add_software_info=True
         # if not float, int, bool, or str, convert to json string
         if not isinstance(value, str):
             save_meta[key] = json.dumps(value)
+    # add the pt format
+    save_meta["format"] = "pt"
     return save_meta

toolkit/models/ilora.py (new file, 104 lines)

@@ -0,0 +1,104 @@
+import weakref
+import torch
+import torch.nn as nn
+from typing import TYPE_CHECKING
+
+from toolkit.models.clip_fusion import ZipperBlock
+
+if TYPE_CHECKING:
+    from toolkit.lora_special import LoRAModule
+    from toolkit.stable_diffusion_model import StableDiffusion
+
+
+class InstantLoRAMidModule(torch.nn.Module):
+    def __init__(
+            self,
+            dim: int,
+            vision_tokens: int,
+            vision_hidden_size: int,
+            lora_module: 'LoRAModule',
+            instant_lora_module: 'InstantLoRAModule'
+    ):
+        super(InstantLoRAMidModule, self).__init__()
+        self.dim = dim
+        self.vision_tokens = vision_tokens
+        self.vision_hidden_size = vision_hidden_size
+        self.lora_module_ref = weakref.ref(lora_module)
+        self.instant_lora_module_ref = weakref.ref(instant_lora_module)
+
+        self.zip = ZipperBlock(
+            in_size=self.vision_hidden_size,
+            in_tokens=self.vision_tokens,
+            out_size=self.dim,
+            out_tokens=1,
+            hidden_size=self.dim,
+            hidden_tokens=self.vision_tokens
+        )
+
+    def forward(self, x, *args, **kwargs):
+        # get the vector
+        img_embeds = self.instant_lora_module_ref().img_embeds
+        # project it
+        scaler = self.zip(img_embeds)  # (batch_size, 1, dim)
+        # remove the channel dim
+        scaler = scaler.squeeze(1)
+        # double up if batch is 2x the size on x (cfg)
+        if x.shape[0] // 2 == scaler.shape[0]:
+            scaler = torch.cat([scaler, scaler], dim=0)
+        # multiply it by the scaler
+        try:
+            # reshape if needed
+            if len(x.shape) == 3:
+                scaler = scaler.unsqueeze(1)
+            x = x * scaler
+        except Exception as e:
+            print(e)
+            print(x.shape)
+            print(scaler.shape)
+            raise e
+        # apply tanh to limit values to -1 to 1
+        scaler = torch.tanh(scaler)
+        return x * scaler
+
+
+class InstantLoRAModule(torch.nn.Module):
+    def __init__(
+            self,
+            vision_hidden_size: int,
+            vision_tokens: int,
+            sd: 'StableDiffusion'
+    ):
+        super(InstantLoRAModule, self).__init__()
+        self.linear = torch.nn.Linear(2, 1)
+        self.sd_ref = weakref.ref(sd)
+        self.dim = sd.network.lora_dim
+        self.vision_hidden_size = vision_hidden_size
+        self.vision_tokens = vision_tokens
+        # stores the projection vector. Grabbed by modules
+        self.img_embeds: torch.Tensor = None
+
+        # disable merging in. It is slower on inference
+        self.sd_ref().network.can_merge_in = False
+
+        self.ilora_modules = torch.nn.ModuleList()
+
+        lora_modules = self.sd_ref().network.get_all_modules()
+
+        for lora_module in lora_modules:
+            # add a new mid module that will take the original forward and add a vector to it
+            # this will be used to add the vector to the original forward
+            mid_module = InstantLoRAMidModule(self.dim, self.vision_tokens, self.vision_hidden_size, lora_module, self)
+            self.ilora_modules.append(mid_module)
+            # replace the LoRA lora_mid
+            lora_module.lora_mid = mid_module.forward
+
+        # add a new mid module that will take the original forward and add a vector to it
+        # this will be used to add the vector to the original forward
+    def forward(self, x):
+        return self.linear(x)
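
Below is a minimal, self-contained sketch of the InstantLoRA pattern this file introduces: each LoRA layer gets a small projector that turns the stored image embeddings into a per-layer scale applied to that layer's output. A plain Linear + Tanh stands in for ZipperBlock; this is illustrative only, not the toolkit implementation:

# Minimal sketch of the image-conditioned per-layer scaling idea (assumptions noted inline).
import weakref
import torch
import torch.nn as nn

class ToyMidModule(nn.Module):
    def __init__(self, dim: int, vision_hidden_size: int, parent: 'ToyInstantLoRA'):
        super().__init__()
        self.parent_ref = weakref.ref(parent)       # weakref avoids a reference cycle back to the parent
        self.proj = nn.Sequential(
            nn.Linear(vision_hidden_size, dim),     # stands in for ZipperBlock
            nn.Tanh(),                              # keep the per-layer scale in [-1, 1]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        img_embeds = self.parent_ref().img_embeds   # (batch, tokens, hidden), set once per step
        scaler = self.proj(img_embeds.mean(dim=1))  # pool tokens -> (batch, dim)
        if x.shape[0] == 2 * scaler.shape[0]:       # duplicate for CFG-doubled batches
            scaler = torch.cat([scaler, scaler], dim=0)
        while scaler.dim() < x.dim():               # broadcast over sequence dims
            scaler = scaler.unsqueeze(1)
        return x * scaler

class ToyInstantLoRA(nn.Module):
    def __init__(self, dim: int, vision_hidden_size: int, num_layers: int):
        super().__init__()
        self.img_embeds: torch.Tensor = None        # stored by the adapter, read by the mid modules
        self.mid_modules = nn.ModuleList(
            ToyMidModule(dim, vision_hidden_size, self) for _ in range(num_layers)
        )

ilora = ToyInstantLoRA(dim=4, vision_hidden_size=768, num_layers=2)
ilora.img_embeds = torch.randn(1, 257, 768)         # e.g. CLIP hidden states
out = ilora.mid_modules[0](torch.randn(2, 10, 4))   # a LoRA activation with a CFG-doubled batch
print(out.shape)                                    # torch.Size([2, 10, 4])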


@@ -215,12 +215,17 @@ def save_ip_adapter_from_diffusers(
         output_file: str,
         meta: 'OrderedDict',
         dtype=get_torch_dtype('fp16'),
+        direct_save: bool = False
 ):
     # todo: test compatibility with non diffusers
     converted_state_dict = OrderedDict()
     for module_name, state_dict in combined_state_dict.items():
-        for key, value in state_dict.items():
-            converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)
+        if direct_save:
+            converted_state_dict[module_name] = state_dict.detach().to('cpu', dtype=dtype)
+        else:
+            for key, value in state_dict.items():
+                converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype)

     # make sure parent folder exists
     os.makedirs(os.path.dirname(output_file), exist_ok=True)
@@ -230,12 +235,15 @@ def save_ip_adapter_from_diffusers(
 def load_ip_adapter_model(
         path_to_file,
         device: Union[str] = 'cpu',
-        dtype: torch.dtype = torch.float32
+        dtype: torch.dtype = torch.float32,
+        direct_load: bool = False
 ):
     # check if it is safetensors or checkpoint
     if path_to_file.endswith('.safetensors'):
         raw_state_dict = load_file(path_to_file, device)
         combined_state_dict = OrderedDict()
+        if direct_load:
+            return raw_state_dict
         for combo_key, value in raw_state_dict.items():
             key_split = combo_key.split('.')
             module_name = key_split.pop(0)