Allow ip adapters to be much more variable in their creation

This commit is contained in:
Jaret Burkett
2023-12-20 06:18:33 -07:00
parent 82098e5d6e
commit dfb64b5957
4 changed files with 89 additions and 57 deletions

View File

@@ -787,25 +787,33 @@ class SDTrainer(BaseSDTrainProcess):
if self.adapter and isinstance(self.adapter, IPAdapter): if self.adapter and isinstance(self.adapter, IPAdapter):
with self.timer('encode_adapter_embeds'): with self.timer('encode_adapter_embeds'):
with torch.no_grad(): if has_adapter_img:
if has_adapter_img: conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( adapter_images.detach().to(self.device_torch, dtype=dtype),
adapter_images.detach().to(self.device_torch, dtype=dtype)) is_training=True
elif is_reg: )
# we will zero it out in the img embedder elif is_reg:
adapter_img = torch.zeros( # we will zero it out in the img embedder
(noisy_latents.shape[0], 3, 512, 512), adapter_img = torch.zeros(
device=self.device_torch, dtype=dtype (noisy_latents.shape[0], 3, 512, 512),
) device=self.device_torch, dtype=dtype
# drop will zero it out ).detach()
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors( # drop will zero it out
adapter_img, drop=True conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
) adapter_img,
else: drop=True,
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image") is_training=True
)
else:
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
if not self.adapter_config.train_image_encoder:
# we are not training the image encoder, so we need to detach the embeds
conditional_clip_embeds = conditional_clip_embeds.detach()
with self.timer('encode_adapter'): with self.timer('encode_adapter'):
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach()) conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
prior_pred = None prior_pred = None
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg): if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg):

View File

@@ -870,6 +870,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
sd=self.sd, sd=self.sd,
adapter_config=self.adapter_config, adapter_config=self.adapter_config,
) )
if self.train_config.gradient_checkpointing:
self.adapter.enable_gradient_checkpointing()
self.adapter.to(self.device_torch, dtype=dtype) self.adapter.to(self.device_torch, dtype=dtype)
if latest_save_path is not None: if latest_save_path is not None:
# load adapter from path # load adapter from path

View File

@@ -142,6 +142,17 @@ class AdapterConfig:
self.image_encoder_path: str = kwargs.get('image_encoder_path', None) self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
self.name_or_path = kwargs.get('name_or_path', None) self.name_or_path = kwargs.get('name_or_path', None)
num_tokens = kwargs.get('num_tokens', None)
if num_tokens is None and self.type.startswith('ip'):
if self.type == 'ip+':
num_tokens = 16
elif self.type == 'ip':
num_tokens = 4
self.num_tokens: int = num_tokens
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
class EmbeddingConfig: class EmbeddingConfig:
def __init__(self, **kwargs): def __init__(self, **kwargs):

View File

@@ -12,7 +12,8 @@ from toolkit.train_tools import get_torch_dtype
sys.path.append(REPOS_ROOT) sys.path.append(REPOS_ROOT)
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
from collections import OrderedDict from collections import OrderedDict
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, AttnProcessor2_0 from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
AttnProcessor2_0
from ipadapter.ip_adapter.ip_adapter import ImageProjModel from ipadapter.ip_adapter.ip_adapter import ImageProjModel
from ipadapter.ip_adapter.resampler import Resampler from ipadapter.ip_adapter.resampler import Resampler
from toolkit.config_modules import AdapterConfig from toolkit.config_modules import AdapterConfig
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
from transformers import ( from transformers import (
CLIPImageProcessor, CLIPImageProcessor,
CLIPVisionModelWithProjection, CLIPVisionModelWithProjection,
CLIPVisionModel
) )
import torch.nn.functional as F import torch.nn.functional as F
@@ -151,9 +153,10 @@ class IPAdapter(torch.nn.Module):
super().__init__() super().__init__()
self.config = adapter_config self.config = adapter_config
self.sd_ref: weakref.ref = weakref.ref(sd) self.sd_ref: weakref.ref = weakref.ref(sd)
self.clip_image_processor = CLIPImageProcessor() self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
self.device = self.sd_ref().unet.device self.device = self.sd_ref().unet.device
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path) self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path,
ignore_mismatched_sizes=True)
self.current_scale = 1.0 self.current_scale = 1.0
self.is_active = True self.is_active = True
if adapter_config.type == 'ip': if adapter_config.type == 'ip':
@@ -161,17 +164,16 @@ class IPAdapter(torch.nn.Module):
image_proj_model = ImageProjModel( image_proj_model = ImageProjModel(
cross_attention_dim=sd.unet.config['cross_attention_dim'], cross_attention_dim=sd.unet.config['cross_attention_dim'],
clip_embeddings_dim=self.image_encoder.config.projection_dim, clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=4, clip_extra_context_tokens=self.config.num_tokens, # usually 4
) )
elif adapter_config.type == 'ip+': elif adapter_config.type == 'ip+':
# ip-adapter-plus # ip-adapter-plus
num_tokens = 16
image_proj_model = Resampler( image_proj_model = Resampler(
dim=sd.unet.config['cross_attention_dim'], dim=sd.unet.config['cross_attention_dim'],
depth=4, depth=4,
dim_head=64, dim_head=64,
heads=12, heads=12,
num_queries=num_tokens, num_queries=self.config.num_tokens, # usually 16
embedding_dim=self.image_encoder.config.hidden_size, embedding_dim=self.image_encoder.config.hidden_size,
output_dim=sd.unet.config['cross_attention_dim'], output_dim=sd.unet.config['cross_attention_dim'],
ff_mult=4 ff_mult=4
@@ -203,20 +205,12 @@ class IPAdapter(torch.nn.Module):
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
} }
if adapter_config.type == 'ip':
# ip-adapter
num_tokens = 4
elif adapter_config.type == 'ip+':
# ip-adapter-plus
num_tokens = 16
else:
raise ValueError(f"unknown adapter type: {adapter_config.type}")
attn_procs[name] = CustomIPAttentionProcessor( attn_procs[name] = CustomIPAttentionProcessor(
hidden_size=hidden_size, hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
scale=1.0, scale=1.0,
num_tokens=num_tokens, num_tokens=self.config.num_tokens,
adapter=self adapter=self
) )
attn_procs[name].load_state_dict(weights) attn_procs[name].load_state_dict(weights)
@@ -249,6 +243,8 @@ class IPAdapter(torch.nn.Module):
self.image_proj_model.load_state_dict(state_dict["image_proj"]) self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"]) ip_layers.load_state_dict(state_dict["ip_adapter"])
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"])
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]): # def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
# self.load_ip_adapter(state_dict) # self.load_ip_adapter(state_dict)
@@ -257,6 +253,8 @@ class IPAdapter(torch.nn.Module):
state_dict = OrderedDict() state_dict = OrderedDict()
state_dict["image_proj"] = self.image_proj_model.state_dict() state_dict["image_proj"] = self.image_proj_model.state_dict()
state_dict["ip_adapter"] = self.adapter_modules.state_dict() state_dict["ip_adapter"] = self.adapter_modules.state_dict()
if self.config.train_image_encoder:
state_dict["image_encoder"] = self.image_encoder.state_dict()
return state_dict return state_dict
def get_scale(self): def get_scale(self):
@@ -281,37 +279,43 @@ class IPAdapter(torch.nn.Module):
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
return clip_image_embeds return clip_image_embeds
@torch.no_grad() def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor: is_training=False) -> torch.Tensor:
# tensors should be 0-1 with torch.no_grad():
# todo: add support for sdxl # tensors should be 0-1
if tensors_0_1.ndim == 3: # todo: add support for sdxl
tensors_0_1 = tensors_0_1.unsqueeze(0) if tensors_0_1.ndim == 3:
# training tensors are 0 - 1 tensors_0_1 = tensors_0_1.unsqueeze(0)
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16) # training tensors are 0 - 1
# if images are out of this range throw error tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3: # if images are out of this range throw error
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format( if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
tensors_0_1.min(), tensors_0_1.max() raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
)) tensors_0_1.min(), tensors_0_1.max()
))
clip_image = self.clip_image_processor( clip_image = self.clip_image_processor(
images=tensors_0_1, images=tensors_0_1,
return_tensors="pt", return_tensors="pt",
do_resize=True, do_resize=True,
do_rescale=False, do_rescale=False,
).pixel_values ).pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16).detach() clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
if drop: if drop:
clip_image = clip_image * 0 clip_image = clip_image * 0
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] with torch.set_grad_enabled(is_training):
if is_training:
self.image_encoder.train()
else:
self.image_encoder.eval()
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
clip_image_embeds = clip_output.hidden_states[-2]
return clip_image_embeds return clip_image_embeds
# use drop for prompt dropout, or negatives # use drop for prompt dropout, or negatives
def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds: def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
clip_image_embeds = clip_image_embeds.detach()
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype)) clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
image_prompt_embeds = self.image_proj_model(clip_image_embeds.detach()) image_prompt_embeds = self.image_proj_model(clip_image_embeds)
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1) embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
return embeddings return embeddings
@@ -319,7 +323,14 @@ class IPAdapter(torch.nn.Module):
for attn_processor in self.adapter_modules: for attn_processor in self.adapter_modules:
yield from attn_processor.parameters(recurse) yield from attn_processor.parameters(recurse)
yield from self.image_proj_model.parameters(recurse) yield from self.image_proj_model.parameters(recurse)
if self.config.train_image_encoder:
yield from self.image_encoder.parameters(recurse)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict) self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict) self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
if self.config.train_image_encoder and 'image_encoder' in state_dict:
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
def enable_gradient_checkpointing(self):
self.image_encoder.gradient_checkpointing = True