mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 03:01:28 +00:00
Allow ip adapters to be much more variable in their creation
This commit is contained in:
@@ -787,25 +787,33 @@ class SDTrainer(BaseSDTrainProcess):
|
|||||||
|
|
||||||
if self.adapter and isinstance(self.adapter, IPAdapter):
|
if self.adapter and isinstance(self.adapter, IPAdapter):
|
||||||
with self.timer('encode_adapter_embeds'):
|
with self.timer('encode_adapter_embeds'):
|
||||||
with torch.no_grad():
|
if has_adapter_img:
|
||||||
if has_adapter_img:
|
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
adapter_images.detach().to(self.device_torch, dtype=dtype),
|
||||||
adapter_images.detach().to(self.device_torch, dtype=dtype))
|
is_training=True
|
||||||
elif is_reg:
|
)
|
||||||
# we will zero it out in the img embedder
|
elif is_reg:
|
||||||
adapter_img = torch.zeros(
|
# we will zero it out in the img embedder
|
||||||
(noisy_latents.shape[0], 3, 512, 512),
|
adapter_img = torch.zeros(
|
||||||
device=self.device_torch, dtype=dtype
|
(noisy_latents.shape[0], 3, 512, 512),
|
||||||
)
|
device=self.device_torch, dtype=dtype
|
||||||
# drop will zero it out
|
).detach()
|
||||||
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
# drop will zero it out
|
||||||
adapter_img, drop=True
|
conditional_clip_embeds = self.adapter.get_clip_image_embeds_from_tensors(
|
||||||
)
|
adapter_img,
|
||||||
else:
|
drop=True,
|
||||||
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
|
is_training=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Adapter images now must be loaded with dataloader or be a reg image")
|
||||||
|
|
||||||
|
if not self.adapter_config.train_image_encoder:
|
||||||
|
# we are not training the image encoder, so we need to detach the embeds
|
||||||
|
conditional_clip_embeds = conditional_clip_embeds.detach()
|
||||||
|
|
||||||
|
|
||||||
with self.timer('encode_adapter'):
|
with self.timer('encode_adapter'):
|
||||||
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds.detach())
|
conditional_embeds = self.adapter(conditional_embeds.detach(), conditional_clip_embeds)
|
||||||
|
|
||||||
prior_pred = None
|
prior_pred = None
|
||||||
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg):
|
if (has_adapter_img and self.assistant_adapter and match_adapter_assist) or (self.do_prior_prediction and not is_reg):
|
||||||
|
|||||||
@@ -870,6 +870,8 @@ class BaseSDTrainProcess(BaseTrainProcess):
|
|||||||
sd=self.sd,
|
sd=self.sd,
|
||||||
adapter_config=self.adapter_config,
|
adapter_config=self.adapter_config,
|
||||||
)
|
)
|
||||||
|
if self.train_config.gradient_checkpointing:
|
||||||
|
self.adapter.enable_gradient_checkpointing()
|
||||||
self.adapter.to(self.device_torch, dtype=dtype)
|
self.adapter.to(self.device_torch, dtype=dtype)
|
||||||
if latest_save_path is not None:
|
if latest_save_path is not None:
|
||||||
# load adapter from path
|
# load adapter from path
|
||||||
|
|||||||
@@ -142,6 +142,17 @@ class AdapterConfig:
|
|||||||
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
|
self.image_encoder_path: str = kwargs.get('image_encoder_path', None)
|
||||||
self.name_or_path = kwargs.get('name_or_path', None)
|
self.name_or_path = kwargs.get('name_or_path', None)
|
||||||
|
|
||||||
|
num_tokens = kwargs.get('num_tokens', None)
|
||||||
|
if num_tokens is None and self.type.startswith('ip'):
|
||||||
|
if self.type == 'ip+':
|
||||||
|
num_tokens = 16
|
||||||
|
elif self.type == 'ip':
|
||||||
|
num_tokens = 4
|
||||||
|
|
||||||
|
self.num_tokens: int = num_tokens
|
||||||
|
self.train_image_encoder: bool = kwargs.get('train_image_encoder', False)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingConfig:
|
class EmbeddingConfig:
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ from toolkit.train_tools import get_torch_dtype
|
|||||||
sys.path.append(REPOS_ROOT)
|
sys.path.append(REPOS_ROOT)
|
||||||
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
|
from typing import TYPE_CHECKING, Union, Iterator, Mapping, Any, Tuple, List
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, AttnProcessor2_0
|
from ipadapter.ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor, IPAttnProcessor2_0, \
|
||||||
|
AttnProcessor2_0
|
||||||
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
from ipadapter.ip_adapter.ip_adapter import ImageProjModel
|
||||||
from ipadapter.ip_adapter.resampler import Resampler
|
from ipadapter.ip_adapter.resampler import Resampler
|
||||||
from toolkit.config_modules import AdapterConfig
|
from toolkit.config_modules import AdapterConfig
|
||||||
@@ -25,6 +26,7 @@ if TYPE_CHECKING:
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
CLIPImageProcessor,
|
CLIPImageProcessor,
|
||||||
CLIPVisionModelWithProjection,
|
CLIPVisionModelWithProjection,
|
||||||
|
CLIPVisionModel
|
||||||
)
|
)
|
||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -151,9 +153,10 @@ class IPAdapter(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = adapter_config
|
self.config = adapter_config
|
||||||
self.sd_ref: weakref.ref = weakref.ref(sd)
|
self.sd_ref: weakref.ref = weakref.ref(sd)
|
||||||
self.clip_image_processor = CLIPImageProcessor()
|
self.clip_image_processor = CLIPImageProcessor.from_pretrained(adapter_config.image_encoder_path)
|
||||||
self.device = self.sd_ref().unet.device
|
self.device = self.sd_ref().unet.device
|
||||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path)
|
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(adapter_config.image_encoder_path,
|
||||||
|
ignore_mismatched_sizes=True)
|
||||||
self.current_scale = 1.0
|
self.current_scale = 1.0
|
||||||
self.is_active = True
|
self.is_active = True
|
||||||
if adapter_config.type == 'ip':
|
if adapter_config.type == 'ip':
|
||||||
@@ -161,17 +164,16 @@ class IPAdapter(torch.nn.Module):
|
|||||||
image_proj_model = ImageProjModel(
|
image_proj_model = ImageProjModel(
|
||||||
cross_attention_dim=sd.unet.config['cross_attention_dim'],
|
cross_attention_dim=sd.unet.config['cross_attention_dim'],
|
||||||
clip_embeddings_dim=self.image_encoder.config.projection_dim,
|
clip_embeddings_dim=self.image_encoder.config.projection_dim,
|
||||||
clip_extra_context_tokens=4,
|
clip_extra_context_tokens=self.config.num_tokens, # usually 4
|
||||||
)
|
)
|
||||||
elif adapter_config.type == 'ip+':
|
elif adapter_config.type == 'ip+':
|
||||||
# ip-adapter-plus
|
# ip-adapter-plus
|
||||||
num_tokens = 16
|
|
||||||
image_proj_model = Resampler(
|
image_proj_model = Resampler(
|
||||||
dim=sd.unet.config['cross_attention_dim'],
|
dim=sd.unet.config['cross_attention_dim'],
|
||||||
depth=4,
|
depth=4,
|
||||||
dim_head=64,
|
dim_head=64,
|
||||||
heads=12,
|
heads=12,
|
||||||
num_queries=num_tokens,
|
num_queries=self.config.num_tokens, # usually 16
|
||||||
embedding_dim=self.image_encoder.config.hidden_size,
|
embedding_dim=self.image_encoder.config.hidden_size,
|
||||||
output_dim=sd.unet.config['cross_attention_dim'],
|
output_dim=sd.unet.config['cross_attention_dim'],
|
||||||
ff_mult=4
|
ff_mult=4
|
||||||
@@ -203,20 +205,12 @@ class IPAdapter(torch.nn.Module):
|
|||||||
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
|
||||||
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
|
||||||
}
|
}
|
||||||
if adapter_config.type == 'ip':
|
|
||||||
# ip-adapter
|
|
||||||
num_tokens = 4
|
|
||||||
elif adapter_config.type == 'ip+':
|
|
||||||
# ip-adapter-plus
|
|
||||||
num_tokens = 16
|
|
||||||
else:
|
|
||||||
raise ValueError(f"unknown adapter type: {adapter_config.type}")
|
|
||||||
|
|
||||||
attn_procs[name] = CustomIPAttentionProcessor(
|
attn_procs[name] = CustomIPAttentionProcessor(
|
||||||
hidden_size=hidden_size,
|
hidden_size=hidden_size,
|
||||||
cross_attention_dim=cross_attention_dim,
|
cross_attention_dim=cross_attention_dim,
|
||||||
scale=1.0,
|
scale=1.0,
|
||||||
num_tokens=num_tokens,
|
num_tokens=self.config.num_tokens,
|
||||||
adapter=self
|
adapter=self
|
||||||
)
|
)
|
||||||
attn_procs[name].load_state_dict(weights)
|
attn_procs[name].load_state_dict(weights)
|
||||||
@@ -249,6 +243,8 @@ class IPAdapter(torch.nn.Module):
|
|||||||
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||||
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
||||||
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||||
|
if self.config.train_image_encoder and 'image_encoder' in state_dict:
|
||||||
|
self.image_encoder.load_state_dict(state_dict["image_encoder"])
|
||||||
|
|
||||||
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
|
# def load_state_dict(self, state_dict: Union[OrderedDict, dict]):
|
||||||
# self.load_ip_adapter(state_dict)
|
# self.load_ip_adapter(state_dict)
|
||||||
@@ -257,6 +253,8 @@ class IPAdapter(torch.nn.Module):
|
|||||||
state_dict = OrderedDict()
|
state_dict = OrderedDict()
|
||||||
state_dict["image_proj"] = self.image_proj_model.state_dict()
|
state_dict["image_proj"] = self.image_proj_model.state_dict()
|
||||||
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
|
state_dict["ip_adapter"] = self.adapter_modules.state_dict()
|
||||||
|
if self.config.train_image_encoder:
|
||||||
|
state_dict["image_encoder"] = self.image_encoder.state_dict()
|
||||||
return state_dict
|
return state_dict
|
||||||
|
|
||||||
def get_scale(self):
|
def get_scale(self):
|
||||||
@@ -281,37 +279,43 @@ class IPAdapter(torch.nn.Module):
|
|||||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||||
return clip_image_embeds
|
return clip_image_embeds
|
||||||
|
|
||||||
@torch.no_grad()
|
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False,
|
||||||
def get_clip_image_embeds_from_tensors(self, tensors_0_1: torch.Tensor, drop=False) -> torch.Tensor:
|
is_training=False) -> torch.Tensor:
|
||||||
# tensors should be 0-1
|
with torch.no_grad():
|
||||||
# todo: add support for sdxl
|
# tensors should be 0-1
|
||||||
if tensors_0_1.ndim == 3:
|
# todo: add support for sdxl
|
||||||
tensors_0_1 = tensors_0_1.unsqueeze(0)
|
if tensors_0_1.ndim == 3:
|
||||||
# training tensors are 0 - 1
|
tensors_0_1 = tensors_0_1.unsqueeze(0)
|
||||||
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
|
# training tensors are 0 - 1
|
||||||
# if images are out of this range throw error
|
tensors_0_1 = tensors_0_1.to(self.device, dtype=torch.float16)
|
||||||
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
|
# if images are out of this range throw error
|
||||||
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
|
if tensors_0_1.min() < -0.3 or tensors_0_1.max() > 1.3:
|
||||||
tensors_0_1.min(), tensors_0_1.max()
|
raise ValueError("image tensor values must be between 0 and 1. Got min: {}, max: {}".format(
|
||||||
))
|
tensors_0_1.min(), tensors_0_1.max()
|
||||||
|
))
|
||||||
|
|
||||||
clip_image = self.clip_image_processor(
|
clip_image = self.clip_image_processor(
|
||||||
images=tensors_0_1,
|
images=tensors_0_1,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
do_resize=True,
|
do_resize=True,
|
||||||
do_rescale=False,
|
do_rescale=False,
|
||||||
).pixel_values
|
).pixel_values
|
||||||
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
|
clip_image = clip_image.to(self.device, dtype=torch.float16).detach()
|
||||||
if drop:
|
if drop:
|
||||||
clip_image = clip_image * 0
|
clip_image = clip_image * 0
|
||||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
with torch.set_grad_enabled(is_training):
|
||||||
|
if is_training:
|
||||||
|
self.image_encoder.train()
|
||||||
|
else:
|
||||||
|
self.image_encoder.eval()
|
||||||
|
clip_output = self.image_encoder(clip_image, output_hidden_states=True)
|
||||||
|
clip_image_embeds = clip_output.hidden_states[-2]
|
||||||
return clip_image_embeds
|
return clip_image_embeds
|
||||||
|
|
||||||
# use drop for prompt dropout, or negatives
|
# use drop for prompt dropout, or negatives
|
||||||
def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
|
def forward(self, embeddings: PromptEmbeds, clip_image_embeds: torch.Tensor) -> PromptEmbeds:
|
||||||
clip_image_embeds = clip_image_embeds.detach()
|
|
||||||
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
clip_image_embeds = clip_image_embeds.to(self.device, dtype=get_torch_dtype(self.sd_ref().dtype))
|
||||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds.detach())
|
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||||
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
|
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
@@ -319,7 +323,14 @@ class IPAdapter(torch.nn.Module):
|
|||||||
for attn_processor in self.adapter_modules:
|
for attn_processor in self.adapter_modules:
|
||||||
yield from attn_processor.parameters(recurse)
|
yield from attn_processor.parameters(recurse)
|
||||||
yield from self.image_proj_model.parameters(recurse)
|
yield from self.image_proj_model.parameters(recurse)
|
||||||
|
if self.config.train_image_encoder:
|
||||||
|
yield from self.image_encoder.parameters(recurse)
|
||||||
|
|
||||||
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
||||||
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
|
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict)
|
||||||
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
|
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=strict)
|
||||||
|
if self.config.train_image_encoder and 'image_encoder' in state_dict:
|
||||||
|
self.image_encoder.load_state_dict(state_dict["image_encoder"], strict=strict)
|
||||||
|
|
||||||
|
def enable_gradient_checkpointing(self):
|
||||||
|
self.image_encoder.gradient_checkpointing = True
|
||||||
|
|||||||
Reference in New Issue
Block a user