From 830e87cb8753b4d06dc3108d8215a7d6b3270fec Mon Sep 17 00:00:00 2001 From: Jaret Burkett Date: Sun, 24 Sep 2023 02:39:43 -0600 Subject: [PATCH] Added IP adapter training. Not functioning correctly yet --- .gitmodules | 3 + extensions_built_in/sd_trainer/SDTrainer.py | 13 +- jobs/process/BaseSDTrainProcess.py | 115 +++++++++----- repositories/ipadapter | 1 + toolkit/config_modules.py | 6 + toolkit/ip_adapter.py | 157 ++++++++++++++++++++ toolkit/network_mixins.py | 1 + toolkit/saving.py | 39 +++++ toolkit/stable_diffusion_model.py | 54 +++++-- 9 files changed, 336 insertions(+), 53 deletions(-) create mode 160000 repositories/ipadapter create mode 100644 toolkit/ip_adapter.py diff --git a/.gitmodules b/.gitmodules index ebf3e531..657cf28b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "repositories/batch_annotator"] path = repositories/batch_annotator url = https://github.com/ostris/batch-annotator +[submodule "repositories/ipadapter"] + path = repositories/ipadapter + url = https://github.com/tencent-ailab/IP-Adapter.git diff --git a/extensions_built_in/sd_trainer/SDTrainer.py b/extensions_built_in/sd_trainer/SDTrainer.py index ed4625f8..eb3b057c 100644 --- a/extensions_built_in/sd_trainer/SDTrainer.py +++ b/extensions_built_in/sd_trainer/SDTrainer.py @@ -2,9 +2,11 @@ import os.path from collections import OrderedDict from PIL import Image +from diffusers import T2IAdapter from torch.utils.data import DataLoader from toolkit.data_transfer_object.data_loader import DataLoaderBatchDTO +from toolkit.ip_adapter import IPAdapter from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork from toolkit.train_tools import get_torch_dtype, apply_snr_weight @@ -115,13 +117,19 @@ class SDTrainer(BaseSDTrainProcess): conditional_embeds = conditional_embeds.detach() # flush() pred_kwargs = {} - if self.adapter: + if self.adapter and isinstance(self.adapter, T2IAdapter): down_block_additional_residuals = self.adapter(adapter_images) down_block_additional_residuals = [ sample.to(dtype=dtype) for sample in down_block_additional_residuals ] pred_kwargs['down_block_additional_residuals'] = down_block_additional_residuals + if self.adapter and isinstance(self.adapter, IPAdapter): + with torch.no_grad(): + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(adapter_images) + conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) + + noise_pred = self.sd.predict_noise( latents=noisy_latents.to(self.device_torch, dtype=dtype), conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype), @@ -164,6 +172,9 @@ class SDTrainer(BaseSDTrainProcess): loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma) loss = loss.mean() + # check if nan + if torch.isnan(loss): + raise ValueError("loss is nan") # IMPORTANT if gradient checkpointing do not leave with network when doing backward # it will destroy the gradients. 
This is because the network is a context manager diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py index 90cedbdb..40dded91 100644 --- a/jobs/process/BaseSDTrainProcess.py +++ b/jobs/process/BaseSDTrainProcess.py @@ -15,6 +15,7 @@ from toolkit.basic import value_map from toolkit.data_loader import get_dataloader_from_datasets from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO from toolkit.embedding import Embedding +from toolkit.ip_adapter import IPAdapter from toolkit.lora_special import LoRASpecialNetwork from toolkit.lycoris_special import LycorisSpecialNetwork from toolkit.network_mixins import Network @@ -22,7 +23,8 @@ from toolkit.optimizer import get_optimizer from toolkit.paths import CONFIG_ROOT from toolkit.progress_bar import ToolkitProgressBar from toolkit.sampler import get_sampler -from toolkit.saving import save_t2i_from_diffusers, load_t2i_model +from toolkit.saving import save_t2i_from_diffusers, load_t2i_model, save_ip_adapter_from_diffusers, \ + load_ip_adapter_model from toolkit.scheduler import get_lr_scheduler from toolkit.sd_device_states_presets import get_train_sd_device_state_preset @@ -118,9 +120,11 @@ class BaseSDTrainProcess(BaseTrainProcess): # to hold network if there is one self.network: Union[Network, None] = None - self.adapter: Union[T2IAdapter, None] = None + self.adapter: Union[T2IAdapter, IPAdapter, None] = None self.embedding: Union[Embedding, None] = None + is_training_adapter = self.adapter_config is not None and self.adapter_config.train + # get the device state preset based on what we are training self.train_device_state_preset = get_train_sd_device_state_preset( device=self.device_torch, @@ -128,17 +132,17 @@ class BaseSDTrainProcess(BaseTrainProcess): train_text_encoder=self.train_config.train_text_encoder, cached_latents=self.is_latents_cached, train_lora=self.network_config is not None, - train_adapter=self.adapter_config is not None, + train_adapter=is_training_adapter, train_embedding=self.embed_config is not None, ) # fine_tuning here is for training actual SD network, not LoRA, embeddings, etc. 
it is (Dreambooth, etc) self.is_fine_tuning = True - if self.network_config is not None or self.adapter_config is not None or self.embed_config is not None: + if self.network_config is not None or is_training_adapter or self.embed_config is not None: self.is_fine_tuning = False self.named_lora = False - if self.embed_config is not None or self.adapter_config is not None: + if self.embed_config is not None or is_training_adapter: self.named_lora = True def post_process_generate_image_config_list(self, generate_image_config_list: List[GenerateImageConfig]): # override in subclass @@ -179,7 +183,7 @@ class BaseSDTrainProcess(BaseTrainProcess): ) extra_args = {} - if self.adapter_config is not None: + if self.adapter_config is not None and self.adapter_config.test_img_path is not None: extra_args['adapter_image_path'] = self.adapter_config.test_img_path gen_img_config_list.append(GenerateImageConfig( @@ -318,22 +322,33 @@ class BaseSDTrainProcess(BaseTrainProcess): emb_file_path = os.path.splitext(emb_file_path)[0] + ".pt" self.embedding.save(emb_file_path) - if self.adapter is not None: + if self.adapter is not None and self.adapter_config.train: adapter_name = self.job.name if self.network_config is not None or self.embedding is not None: # add _lora to name - adapter_name += '_t2i' + if self.adapter_config.type == 't2i': + adapter_name += '_t2i' + else: + adapter_name += '_ip' filename = f'{adapter_name}{step_num}.safetensors' file_path = os.path.join(self.save_root, filename) # save adapter state_dict = self.adapter.state_dict() - save_t2i_from_diffusers( - state_dict, - output_file=file_path, - meta=save_meta, - dtype=get_torch_dtype(self.save_config.dtype) - ) + if self.adapter_config.type == 't2i': + save_t2i_from_diffusers( + state_dict, + output_file=file_path, + meta=save_meta, + dtype=get_torch_dtype(self.save_config.dtype) + ) + else: + save_ip_adapter_from_diffusers( + state_dict, + output_file=file_path, + meta=save_meta, + dtype=get_torch_dtype(self.save_config.dtype) + ) else: self.sd.save( file_path, @@ -527,6 +542,50 @@ class BaseSDTrainProcess(BaseTrainProcess): return noisy_latents, noise, timesteps, conditioned_prompts, imgs + def setup_adapter(self): + dtype = get_torch_dtype(self.train_config.dtype) + is_t2i = self.adapter_config.type == 't2i' + if is_t2i: + self.adapter = T2IAdapter( + in_channels=self.adapter_config.in_channels, + channels=self.adapter_config.channels, + num_res_blocks=self.adapter_config.num_res_blocks, + downscale_factor=self.adapter_config.downscale_factor, + adapter_type=self.adapter_config.adapter_type, + ) + else: + self.adapter = IPAdapter( + sd=self.sd, + adapter_config=self.adapter_config, + ) + self.adapter.to(self.device_torch, dtype=dtype) + # t2i adapter + suffix = 't2i' if is_t2i else 'ip' + adapter_name = self.name + if self.network_config is not None: + adapter_name = f"{adapter_name}_{suffix}" + latest_save_path = self.get_latest_save_path(adapter_name) + if latest_save_path is not None: + # load adapter from path + print(f"Loading adapter from {latest_save_path}") + if is_t2i: + loaded_state_dict = load_t2i_model( + latest_save_path, + self.device, + dtype=dtype + ) + else: + loaded_state_dict = load_ip_adapter_model( + latest_save_path, + self.device, + dtype=dtype + ) + self.adapter.load_state_dict(loaded_state_dict) + if self.adapter_config.train: + self.load_training_state_from_metadata(latest_save_path) + # set trainable params + self.sd.adapter = self.adapter + def run(self): # torch.autograd.set_detect_anomaly(True) # run base 
process run @@ -741,35 +800,12 @@ class BaseSDTrainProcess(BaseTrainProcess): flush() if self.adapter_config is not None: - self.adapter = T2IAdapter( - in_channels=self.adapter_config.in_channels, - channels=self.adapter_config.channels, - num_res_blocks=self.adapter_config.num_res_blocks, - downscale_factor=self.adapter_config.downscale_factor, - adapter_type=self.adapter_config.adapter_type, - ) - self.adapter.to(self.device_torch, dtype=dtype) - # t2i adapter - adapter_name = self.name - if self.network_config is not None: - adapter_name = f"{adapter_name}_t2i" - latest_save_path = self.get_latest_save_path(adapter_name) - if latest_save_path is not None: - # load adapter from path - print(f"Loading adapter from {latest_save_path}") - loaded_state_dict = load_t2i_model( - latest_save_path, - self.device, - dtype=dtype - ) - self.adapter.load_state_dict(loaded_state_dict) - self.load_training_state_from_metadata(latest_save_path) + self.setup_adapter() # set trainable params params.append({ 'params': self.adapter.parameters(), 'lr': self.train_config.adapter_lr }) - self.sd.adapter = self.adapter flush() params = self.load_additional_training_modules(params) @@ -785,6 +821,9 @@ class BaseSDTrainProcess(BaseTrainProcess): unet_lr=self.train_config.lr, default_lr=self.train_config.lr ) + # we may be using it for prompt injections + if self.adapter_config is not None: + self.setup_adapter() flush() ### HOOK ### params = self.hook_add_extra_train_params(params) diff --git a/repositories/ipadapter b/repositories/ipadapter new file mode 160000 index 00000000..d8ab37c4 --- /dev/null +++ b/repositories/ipadapter @@ -0,0 +1 @@ +Subproject commit d8ab37c421c1ab95d15abe094e8266a6d01e26ef diff --git a/toolkit/config_modules.py b/toolkit/config_modules.py index 7e907ff1..6e0dc2f1 100644 --- a/toolkit/config_modules.py +++ b/toolkit/config_modules.py @@ -61,8 +61,11 @@ class NetworkConfig: self.dropout: Union[float, None] = kwargs.get('dropout', None) +AdapterTypes = Literal['t2i', 'ip', 'ip+'] + class AdapterConfig: def __init__(self, **kwargs): + self.type: AdapterTypes = kwargs.get('type', 't2i') # t2i, ip self.in_channels: int = kwargs.get('in_channels', 3) self.channels: List[int] = kwargs.get('channels', [320, 640, 1280, 1280]) self.num_res_blocks: int = kwargs.get('num_res_blocks', 2) @@ -70,6 +73,9 @@ class AdapterConfig: self.adapter_type: str = kwargs.get('adapter_type', 'full_adapter') self.image_dir: str = kwargs.get('image_dir', None) self.test_img_path: str = kwargs.get('test_img_path', None) + self.train: str = kwargs.get('train', False) + self.image_encoder_path: str = kwargs.get('image_encoder_path', None) + self.name_or_path = kwargs.get('name_or_path', None) class EmbeddingConfig: diff --git a/toolkit/ip_adapter.py b/toolkit/ip_adapter.py new file mode 100644 index 00000000..f252af28 --- /dev/null +++ b/toolkit/ip_adapter.py @@ -0,0 +1,157 @@ +import torch +import sys + +from PIL import Image +from torch.nn import Parameter +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection + +from toolkit.paths import REPOS_ROOT +from toolkit.train_tools import get_torch_dtype + +sys.path.append(REPOS_ROOT) +from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List +from collections import OrderedDict +from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor +from ipadapter.ip_adapter.ip_adapter import ImageProjModel +from ipadapter.ip_adapter.resampler import Resampler +from toolkit.config_modules import AdapterConfig +from 
toolkit.prompt_utils import PromptEmbeds +import weakref + +if TYPE_CHECKING: + from toolkit.stable_diffusion_model import StableDiffusion + + +# loosely based on # ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py +class IPAdapter(torch.nn.Module): + """IP-Adapter""" + + def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'): + super().__init__() + self.config = adapter_config + self.sd_ref: weakref.ref = weakref.ref(sd) + self.clip_image_processor = CLIPImageProcessor() + self.device = self.sd_ref().unet.device + self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path) + if adapter_config.type == 'ip': + # ip-adapter + image_proj_model = ImageProjModel( + cross_attention_dim=sd.unet.config['cross_attention_dim'], + clip_embeddings_dim=self.image_encoder.config.projection_dim, + clip_extra_context_tokens=4, + ) + elif adapter_config.type == 'ip+': + # ip-adapter-plus + num_tokens = 16 + image_proj_model = Resampler( + dim=sd.unet.config['cross_attention_dim'], + depth=4, + dim_head=64, + heads=12, + num_queries=num_tokens, + embedding_dim=self.image_encoder.config.hidden_size, + output_dim=sd.unet.config['cross_attention_dim'], + ff_mult=4 + ) + else: + raise ValueError(f"unknown adapter type: {adapter_config.type}") + + # init adapter modules + attn_procs = {} + unet_sd = sd.unet.state_dict() + for name in sd.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim'] + if name.startswith("mid_block"): + hidden_size = sd.unet.config['block_out_channels'][-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = sd.unet.config['block_out_channels'][block_id] + else: + # they didnt have this, but would lead to undefined below + raise ValueError(f"unknown attn processor name: {name}") + if cross_attention_dim is None: + attn_procs[name] = AttnProcessor() + else: + layer_name = name.split(".processor")[0] + weights = { + "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], + "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], + } + attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + attn_procs[name].load_state_dict(weights) + sd.unet.set_attn_processor(attn_procs) + adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values()) + + sd.adapter = self + self.unet_ref: weakref.ref = weakref.ref(sd.unet) + self.image_proj_model = image_proj_model + self.adapter_modules = adapter_modules + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.image_encoder.to(*args, **kwargs) + self.image_proj_model.to(*args, **kwargs) + self.adapter_modules.to(*args, **kwargs) + return self + + def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]): + self.image_proj_model.load_state_dict(state_dict["image_proj"]) + ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) + ip_layers.load_state_dict(state_dict["ip_adapter"]) + + def state_dict(self) -> OrderedDict: + state_dict = OrderedDict() + state_dict["image_proj"] = self.image_proj_model.state_dict() + state_dict["ip_adapter"] = self.adapter_modules.state_dict() + return state_dict + + def set_scale(self, scale): + for attn_processor in self.pipe.unet.attn_processors.values(): + if 
isinstance(attn_processor, IPAttnProcessor): + attn_processor.scale = scale + + @torch.no_grad() + def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor: + # todo: add support for sdxl + if isinstance(pil_image, Image.Image): + pil_image = [pil_image] + clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + if drop: + clip_image = clip_image * 0 + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + return clip_image_embeds + + @torch.no_grad() + def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor: + # tensors should be 0-1 + # todo: add support for sdxl + if tensors_0_1.ndim == 3: + tensors_0_1 = tensors_0_1.unsqueeze(0) + tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) + clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values + clip_image = clip_image.to(self.device, dtype=torch.float16) + if drop: + clip_image = clip_image * 0 + clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] + return clip_image_embeds + + # use drop for prompt dropout, or negatives + def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds: + clip_image_embeds = clip_image_embeds.detach() + clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) + image_prompt_embeds = self.image_proj_model(clip_image_embeds.detach()) + embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1) + return embeddings + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + for attn_processor in self.adapter_modules: + yield from attn_processor.parameters(recurse) + yield from self.image_proj_model.parameters(recurse) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) + self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict) diff --git a/toolkit/network_mixins.py b/toolkit/network_mixins.py index b592670f..b5d88e8c 100644 --- a/toolkit/network_mixins.py +++ b/toolkit/network_mixins.py @@ -306,6 +306,7 @@ class ToolkitNetworkMixin: extra_dict = None return extra_dict + @torch.no_grad() def _update_torch_multiplier(self: Network): # builds a tensor for fast usage in the forward pass of the network modules # without having to set it in every single module every time it changes diff --git a/toolkit/saving.py b/toolkit/saving.py index 8d172b71..5a22b966 100644 --- a/toolkit/saving.py +++ b/toolkit/saving.py @@ -193,3 +193,42 @@ def load_t2i_model( # todo see if we need to convert dict converted_state_dict[key] = value.detach().to(device, dtype=dtype) return converted_state_dict + + +IP_ADAPTER_MODULES = ['image_proj', 'ip_adapter'] + +def save_ip_adapter_from_diffusers( + combined_state_dict: 'OrderedDict', + output_file: str, + meta: 'OrderedDict', + dtype=get_torch_dtype('fp16'), +): + # todo: test compatibility with non diffusers + converted_state_dict = OrderedDict() + for module_name, state_dict in combined_state_dict.items(): + for key, value in state_dict.items(): + converted_state_dict[f"{module_name}.{key}"] = value.detach().to('cpu', dtype=dtype) + + # make sure parent folder exists + 
os.makedirs(os.path.dirname(output_file), exist_ok=True) + save_file(converted_state_dict, output_file, metadata=meta) + + +def load_ip_adapter_model( + path_to_file, + device: Union[str] = 'cpu', + dtype: torch.dtype = torch.float32 +): + # check if it is safetensors or checkpoint + if path_to_file.endswith('.safetensors'): + raw_state_dict = load_file(path_to_file, device) + combined_state_dict = OrderedDict() + for combo_key, value in raw_state_dict.items(): + key_split = combo_key.split('.') + module_name = key_split.pop(0) + if module_name not in combined_state_dict: + combined_state_dict[module_name] = OrderedDict() + combined_state_dict[module_name]['.'.join(key_split)] = value.detach().to(device, dtype=dtype) + return combined_state_dict + else: + return torch.load(path_to_file, map_location=device) diff --git a/toolkit/stable_diffusion_model.py b/toolkit/stable_diffusion_model.py index 7cabe84c..1296a856 100644 --- a/toolkit/stable_diffusion_model.py +++ b/toolkit/stable_diffusion_model.py @@ -13,8 +13,9 @@ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import from safetensors.torch import save_file, load_file from torch.nn import Parameter from tqdm import tqdm -from torchvision.transforms import Resize +from torchvision.transforms import Resize, transforms +from toolkit.ip_adapter import IPAdapter from library.model_util import convert_unet_state_dict_to_sd, convert_text_encoder_state_dict_to_sd_v2, \ convert_vae_state_dict, load_vae from toolkit import train_tools @@ -123,7 +124,7 @@ class StableDiffusion: # to hold network if there is one self.network = None - self.adapter: Union['T2IAdapter', None] = None + self.adapter: Union['T2IAdapter', 'IPAdapter', None] = None self.is_xl = model_config.is_xl self.is_v2 = model_config.is_v2 @@ -302,12 +303,13 @@ class StableDiffusion: Pipe = StableDiffusionPipeline extra_args = {} - if self.adapter: - if self.is_xl: - Pipe = StableDiffusionXLAdapterPipeline - else: - Pipe = StableDiffusionAdapterPipeline - extra_args['adapter'] = self.adapter + if self.adapter is not None: + if isinstance(self.adapter, T2IAdapter): + if self.is_xl: + Pipe = StableDiffusionXLAdapterPipeline + else: + Pipe = StableDiffusionAdapterPipeline + extra_args['adapter'] = self.adapter # TODO add clip skip if self.is_xl: @@ -358,11 +360,19 @@ class StableDiffusion: gen_config = image_configs[i] extra = {} - if gen_config.adapter_image_path is not None: + if self.adapter is not None and gen_config.adapter_image_path is not None: validation_image = Image.open(gen_config.adapter_image_path).convert("RGB") - validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) - extra['image'] = validation_image - extra['adapter_conditioning_scale'] = 1.0 + if isinstance(self.adapter, T2IAdapter): + # not sure why this is double?? 
+ validation_image = validation_image.resize((gen_config.width * 2, gen_config.height * 2)) + extra['image'] = validation_image + extra['adapter_conditioning_scale'] = 1.0 + if isinstance(self.adapter, IPAdapter): + transform = transforms.Compose([ + transforms.Resize(gen_config.width, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.PILToTensor(), + ]) + validation_image = transform(validation_image) if self.network is not None: self.network.multiplier = gen_config.network_multiplier @@ -382,6 +392,14 @@ class StableDiffusion: unconditional_embeds, ) + if self.adapter is not None and isinstance(self.adapter, IPAdapter) and gen_config.adapter_image_path is not None: + # apply the image projection + conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image) + unconditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(validation_image, True) + conditional_embeds = self.adapter(conditional_embeds, conditional_clip_embeds) + unconditional_embeds = self.adapter(unconditional_embeds, unconditional_clip_embeds) + + # todo do we disable text encoder here as well if disabled for model, or only do that for training? if self.is_xl: # fix guidance rescale for sdxl @@ -931,10 +949,18 @@ class StableDiffusion: 'requires_grad': self.text_encoder.text_model.final_layer_norm.weight.requires_grad } if self.adapter is not None: + if isinstance(self.adapter, IPAdapter): + requires_grad = self.adapter.adapter_modules.training + adapter_device = self.unet.device + elif isinstance(self.adapter, T2IAdapter): + requires_grad = self.adapter.adapter.conv_in.weight.requires_grad + adapter_device = self.adapter.device + else: + raise ValueError(f"Unknown adapter type: {type(self.adapter)}") self.device_state['adapter'] = { 'training': self.adapter.training, - 'device': self.adapter.device, - 'requires_grad': self.adapter.adapter.conv_in.weight.requires_grad, + 'device': adapter_device, + 'requires_grad': requires_grad, } def restore_device_state(self):
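
For reviewers, a rough usage sketch of the IP-Adapter path this patch wires up, assuming the pieces the diff already provides: `sd` stands in for the live `StableDiffusion` instance, `adapter_images` for the 0-1 image batch the data loader yields, `conditional_embeds` for the `PromptEmbeds` returned by the text encoder, and the `image_encoder_path` value is a placeholder for whatever `CLIPVisionModelWithProjection` checkpoint is used. This is a sketch of the intended flow, not part of the commit.

    import torch

    from toolkit.config_modules import AdapterConfig
    from toolkit.ip_adapter import IPAdapter

    # Mirrors the kwargs this patch adds to AdapterConfig ('ip' or 'ip+' selects
    # ImageProjModel vs. the Resampler variant).
    adapter_config = AdapterConfig(
        type='ip',
        image_encoder_path='/path/to/clip_image_encoder',  # placeholder path
        train=True,
    )

    # Constructing the adapter swaps the UNet attn processors for IPAttnProcessor
    # and registers itself on sd.adapter, as done in setup_adapter().
    adapter = IPAdapter(sd=sd, adapter_config=adapter_config)
    adapter.to(sd.unet.device, dtype=torch.float16)

    # Inside a training step: encode the adapter images with the frozen CLIP image
    # encoder, then let the adapter append the projected image tokens to the text
    # embeddings before predict_noise() is called.
    with torch.no_grad():
        clip_embeds = adapter.get_clip_image_embeds_from_tensors(adapter_images)
    conditional_embeds = adapter(conditional_embeds, clip_embeds)

Sampling follows the same shape: the validation image is converted to a tensor, embedded once normally and once with drop=True for the unconditional branch, and both prompt embeddings are run through the adapter before generation.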