import torch
import sys
from PIL import Image
from torch.nn import Parameter
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from toolkit.paths import REPOS_ROOT
from toolkit.train_tools import get_torch_dtype

sys.path.append(REPOS_ROOT)

from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig
from toolkit.prompt_utils import PromptEmbeds
import weakref

if TYPE_CHECKING:
    from toolkit.stable_diffusion_model import StableDiffusion


# loosely based on
# ref https://github.com/tencent-ailab/IP-Adapter/blob/main/tutorial_train.py
class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, sd: 'StableDiffusion', adapter_config: 'AdapterConfig'):
        super().__init__()
        self.config = adapter_config
        self.sd_ref: weakref.ref = weakref.ref(sd)
        self.clip_image_processor = CLIPImageProcessor()
        self.device = self.sd_ref().unet.device
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)

        if adapter_config.type == 'ip':
            # ip-adapter: linear projection producing 4 extra context tokens
            image_proj_model = ImageProjModel(
                cross_attention_dim=sd.unet.config['cross_attention_dim'],
                clip_embeddings_dim=self.image_encoder.config.projection_dim,
                clip_extra_context_tokens=4,
            )
        elif adapter_config.type == 'ip+':
            # ip-adapter-plus: perceiver-style resampler producing 16 image tokens
            num_tokens = 16
            image_proj_model = Resampler(
                dim=sd.unet.config['cross_attention_dim'],
                depth=4,
                dim_head=64,
                heads=12,
                num_queries=num_tokens,
                embedding_dim=self.image_encoder.config.hidden_size,
                output_dim=sd.unet.config['cross_attention_dim'],
                ff_mult=4,
            )
        else:
            raise ValueError(f"unknown adapter type: {adapter_config.type}")

        # init adapter modules: replace each cross-attention processor in the UNet with an
        # IPAttnProcessor whose to_k_ip / to_v_ip weights start as copies of the UNet's own
        # to_k / to_v weights
        attn_procs = {}
        unet_sd = sd.unet.state_dict()
        for name in sd.unet.attn_processors.keys():
            # attn1 is self-attention (no cross-attention dim); attn2 attends to the text embeddings
            cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
            if name.startswith("mid_block"):
                hidden_size = sd.unet.config['block_out_channels'][-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(sd.unet.config['block_out_channels']))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = sd.unet.config['block_out_channels'][block_id]
            else:
                # the upstream implementation did not handle this case, which would leave hidden_size undefined below
                raise ValueError(f"unknown attn processor name: {name}")
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
                layer_name = name.split(".processor")[0]
                weights = {
                    "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                    "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
                }
                attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
                attn_procs[name].load_state_dict(weights)
        sd.unet.set_attn_processor(attn_procs)
        adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())

        sd.adapter = self
        self.unet_ref: weakref.ref = weakref.ref(sd.unet)
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.image_encoder.to(*args, **kwargs)
        self.image_proj_model.to(*args, **kwargs)
        self.adapter_modules.to(*args, **kwargs)
        return self

    def load_ip_adapter(self, state_dict: Union[OrderedDict, dict]):
        self.image_proj_model.load_state_dict(state_dict["image_proj"])
        # the class has no `pipe`; the attention processors live on the referenced UNet
        ip_layers = torch.nn.ModuleList(self.unet_ref().attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    def state_dict(self) -> OrderedDict:
        state_dict = OrderedDict()
        state_dict["image_proj"] = self.image_proj_model.state_dict()
        state_dict["ip_adapter"] = self.adapter_modules.state_dict()
        return state_dict

    def set_scale(self, scale):
        for attn_processor in self.unet_ref().attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

    @torch.no_grad()
    def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]], drop=False) -> torch.Tensor:
        # todo: add support for sdxl
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        if drop:
            # zero the image for dropout / unconditional (negative) embeddings
            clip_image = clip_image * 0
        # penultimate hidden state of the CLIP vision tower
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        return clip_image_embeds

    @torch.no_grad()
    def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor:
        # tensors should be 0-1
        # todo: add support for sdxl
        if tensors_0_1.ndim == 3:
            tensors_0_1 = tensors_0_1.unsqueeze(0)
        tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
        clip_image = self.clip_image_processor(images=tensors_0_1, return_tensors="pt", do_resize=False).pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        if drop:
            clip_image = clip_image * 0
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        return clip_image_embeds

    # use `drop` on the embed helpers above for prompt dropout, or for negatives
    def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
        clip_image_embeds = clip_image_embeds.detach()
        clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
        # project image features into extra context tokens and append them to the text embeddings
        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
        embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
        return embeddings

    def parameters(self, recurse: bool = True) -> Iterator[Parameter]:
        # train only the adapter weights: the IP attention projections and the image projection model
        for attn_processor in self.adapter_modules:
            yield from attn_processor.parameters(recurse)
        yield from self.image_proj_model.parameters(recurse)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
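

# --- Hypothetical usage sketch (illustration only; not part of the upstream file) ---
# It assumes an already-loaded `StableDiffusion` wrapper `sd` exposing `.unet` and
# `.dtype`, text embeddings `prompt_embeds: PromptEmbeds` produced elsewhere in the
# toolkit, and that `AdapterConfig` accepts `type` and `image_encoder_path` as keyword
# arguments; the encoder path and file names below are placeholders.
#
#   adapter_config = AdapterConfig(
#       type='ip+',
#       image_encoder_path='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
#   )
#   adapter = IPAdapter(sd, adapter_config).to(sd.unet.device, dtype=torch.float16)
#
#   # optionally resume from a previously saved adapter checkpoint
#   # adapter.load_ip_adapter(torch.load('ip_adapter.bin', map_location='cpu'))
#
#   # encode a reference image and append its tokens to the text conditioning
#   clip_embeds = adapter.get_clip_image_embeds_from_pil(Image.open('reference.png'))
#   conditioned = adapter(prompt_embeds, clip_embeds)
#
#   # `conditioned.text_embeds` is then passed to the UNet as encoder_hidden_states;
#   # adapter.set_scale(0.5) would halve the influence of the image tokens at inference.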