Mirror of https://github.com/ostris/ai-toolkit.git, synced 2026-04-30 11:11:37 +00:00
Bug fixes. Added IP adapter training for Pixart
@@ -4,6 +4,7 @@ import torch
 import sys
 
 from PIL import Image
+from diffusers import Transformer2DModel
 from torch.nn import Parameter
 from torch.nn.modules.module import T
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -79,6 +80,10 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
         )
 
+        if is_active:
+            # since we are removing tokens, we need to adjust the sequence length
+            sequence_length = sequence_length - self.num_tokens
+
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
             # scaled_dot_product_attention expects attention_mask shape to be
@@ -90,6 +95,9 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
 
         query = attn.to_q(hidden_states)
 
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
+        # will be none if disabled
+        if not is_active:
+            ip_hidden_states = None
@@ -120,9 +128,13 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+        try:
+            hidden_states = F.scaled_dot_product_attention(
+                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            )
+        except Exception as e:
+            print(e)
+            raise e
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
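The hunk above wraps F.scaled_dot_product_attention in a try/except so shape or mask mismatches are printed with context before being re-raised. A minimal standalone sketch of the same pattern, with illustrative names that are not part of the commit:

import torch
import torch.nn.functional as F

def attend(query, key, value, attention_mask=None):
    # same idea as the hunk: run SDPA, surface the error for easier debugging
    # of shape/mask mismatches, then re-raise so training still stops
    try:
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
    except Exception as e:
        print(f"SDPA failed: q={tuple(query.shape)} k={tuple(key.shape)} v={tuple(value.shape)}")
        raise e

q = torch.randn(2, 8, 16, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
out = attend(q, k, v)           # -> shape (2, 8, 16, 64)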
@@ -235,7 +247,7 @@ class IPAdapter(torch.nn.Module):
             print(f"could not load image processor from {adapter_config.image_encoder_path}")
             self.clip_image_processor = ConvNextImageProcessor(
                 size=512,
-                image_mean=[0.485,0.456,0.406],
+                image_mean=[0.485, 0.456, 0.406],
                 image_std=[0.229, 0.224, 0.225],
             )
             self.image_encoder = ConvNextV2ForImageClassification.from_pretrained(
@@ -299,6 +311,7 @@ class IPAdapter(torch.nn.Module):
             raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
         self.current_scale = 1.0
         self.is_active = True
+        is_pixart = sd.is_pixart
         if adapter_config.type == 'ip':
             # ip-adapter
             image_proj_model = ImageProjModel(
@@ -310,14 +323,22 @@ class IPAdapter(torch.nn.Module):
             heads = 12 if not sd.is_xl else 20
             dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
             embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
-                self.image_encoder.config.hidden_sizes[-1]
+                self.image_encoder.config.hidden_sizes[-1]
 
             image_encoder_state_dict = self.image_encoder.state_dict()
             # max_seq_len = CLIP tokens + CLS token
             max_seq_len = 257
             if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
                 # clip
-                max_seq_len = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
+                max_seq_len = int(
+                    image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
 
+            output_dim = sd.unet.config['cross_attention_dim']
+
+            if is_pixart:
+                heads = 20
+                dim = 4096
+                output_dim = 4096
+
             # ip-adapter-plus
             image_proj_model = Resampler(
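For Pixart the hunk above overrides the Resampler's head count and width (heads=20, dim=4096) and routes the output width through an output_dim variable instead of hard-coding the UNet's cross_attention_dim. A hedged helper showing the same selection logic; the function name and example values are illustrative:

def resampler_dims(cross_attention_dim: int, is_xl: bool, is_pixart: bool):
    # mirrors the branch in the hunk above: SDXL and Pixart use wider attention
    heads = 12 if not is_xl else 20
    dim = cross_attention_dim if not is_xl else 1280
    output_dim = cross_attention_dim
    if is_pixart:
        heads = 20
        dim = 4096
        output_dim = 4096
    return heads, dim, output_dim

print(resampler_dims(768, is_xl=False, is_pixart=False))   # (12, 768, 768)
print(resampler_dims(1152, is_xl=False, is_pixart=True))   # (20, 4096, 4096)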
@@ -328,7 +349,7 @@ class IPAdapter(torch.nn.Module):
                 num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len,
                 embedding_dim=embedding_dim,
                 max_seq_len=max_seq_len,
-                output_dim=sd.unet.config['cross_attention_dim'],
+                output_dim=output_dim,
                 ff_mult=4
             )
         elif adapter_config.type == 'ipz':
@@ -373,8 +394,21 @@ class IPAdapter(torch.nn.Module):
         # init adapter modules
         attn_procs = {}
         unet_sd = sd.unet.state_dict()
-        for name in sd.unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
+        attn_processor_keys = []
+        if is_pixart:
+            transformer: Transformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
+
+                # cross attention
+                attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
+
+        else:
+            attn_processor_keys = list(sd.unet.attn_processors.keys())
+
+        for name in attn_processor_keys:
+            cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn.1") else \
+                sd.unet.config['cross_attention_dim']
             if name.startswith("mid_block"):
                 hidden_size = sd.unet.config['block_out_channels'][-1]
             elif name.startswith("up_blocks"):
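Pixart's Transformer2DModel does not expose the same attn_processors key layout the UNet path relied on, so the hunk above builds the key list by walking transformer_blocks and registering both attn1 (self-attention) and attn2 (cross-attention) per block. A toy sketch of the same enumeration over a stand-in module; the block and attention classes here are placeholders, not diffusers types:

import torch.nn as nn

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn1 = nn.Identity()  # stands in for self-attention
        self.attn2 = nn.Identity()  # stands in for cross-attention

class TinyTransformer(nn.Module):
    def __init__(self, depth=3):
        super().__init__()
        self.transformer_blocks = nn.ModuleList([Block() for _ in range(depth)])

model = TinyTransformer()
keys = []
for i, _module in model.transformer_blocks.named_children():
    keys.append(f"transformer_blocks.{i}.attn1")
    keys.append(f"transformer_blocks.{i}.attn2")
print(keys)  # ['transformer_blocks.0.attn1', 'transformer_blocks.0.attn2', ...]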
@@ -383,6 +417,8 @@ class IPAdapter(torch.nn.Module):
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
                 hidden_size = sd.unet.config['block_out_channels'][block_id]
+            elif name.startswith("transformer"):
+                hidden_size = sd.unet.config['cross_attention_dim']
             else:
                 # they didnt have this, but would lead to undefined below
                 raise ValueError(f"unknown attn processor name: {name}")
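The new "transformer" branch gives Pixart processor names a hidden size, since its block names do not follow the UNet's down/mid/up naming. A simplified sketch of the name-to-width mapping; the config numbers are made up and the up_blocks case from the real code is omitted:

config = {"block_out_channels": [320, 640, 1280], "cross_attention_dim": 1152}

def hidden_size_for(name: str) -> int:
    if name.startswith("mid_block"):
        return config["block_out_channels"][-1]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return config["block_out_channels"][block_id]
    if name.startswith("transformer"):
        return config["cross_attention_dim"]
    raise ValueError(f"unknown attn processor name: {name}")

print(hidden_size_for("transformer_blocks.0.attn2"))  # 1152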
@@ -402,14 +438,35 @@ class IPAdapter(torch.nn.Module):
                     num_tokens=self.config.num_tokens,
                     adapter=self
                 )
+                if self.sd_ref().is_pixart:
+                    # pixart is much more sensitive
+                    weights = {
+                        "to_k_ip.weight": weights["to_k_ip.weight"] * 0.01,
+                        "to_v_ip.weight": weights["to_v_ip.weight"] * 0.01,
+                    }
+
                 attn_procs[name].load_state_dict(weights)
-        sd.unet.set_attn_processor(attn_procs)
-        adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
+        if self.sd_ref().is_pixart:
+            # we have to set them ourselves
+            transformer: Transformer2DModel = sd.unet
+            for i, module in transformer.transformer_blocks.named_children():
+                module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
+                module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
+            self.adapter_modules = torch.nn.ModuleList(
+                [
+                    transformer.transformer_blocks[i].attn1.processor for i in
+                    range(len(transformer.transformer_blocks))
+                ] + [
+                    transformer.transformer_blocks[i].attn2.processor for i in
+                    range(len(transformer.transformer_blocks))
+                ])
+        else:
+            sd.unet.set_attn_processor(attn_procs)
+            self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
 
         sd.adapter = self
         self.unet_ref: weakref.ref = weakref.ref(sd.unet)
         self.image_proj_model = image_proj_model
-        self.adapter_modules = adapter_modules
         # load the weights if we have some
         if self.config.name_or_path:
             loaded_state_dict = load_ip_adapter_model(
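Two Pixart-specific changes land in the hunk above: the IP cross-attention projections are still seeded from the existing to_k/to_v weights but scaled down by 0.01 because Pixart reacts strongly to the new branch, and the processors are attached to each block's attn1/attn2 by hand since set_attn_processor is only used on the UNet path. A hedged, standalone sketch of the down-scaled initialization; the layer sizes are arbitrary:

import torch
import torch.nn as nn

cross_attention_dim, hidden_size = 64, 128

# existing cross-attention projections we want to start from
to_k = nn.Linear(cross_attention_dim, hidden_size, bias=False)
to_v = nn.Linear(cross_attention_dim, hidden_size, bias=False)

# new image-prompt projections, initialized to a heavily damped copy so the
# adapter starts as a near no-op and training stays stable
to_k_ip = nn.Linear(cross_attention_dim, hidden_size, bias=False)
to_v_ip = nn.Linear(cross_attention_dim, hidden_size, bias=False)
with torch.no_grad():
    to_k_ip.weight.copy_(to_k.weight * 0.01)
    to_v_ip.weight.copy_(to_v.weight * 0.01)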
@@ -473,9 +530,10 @@ class IPAdapter(torch.nn.Module):
 
     def set_scale(self, scale):
         self.current_scale = scale
-        for attn_processor in self.sd_ref().unet.attn_processors.values():
-            if isinstance(attn_processor, CustomIPAttentionProcessor):
-                attn_processor.scale = scale
+        if not self.sd_ref().is_pixart:
+            for attn_processor in self.sd_ref().unet.attn_processors.values():
+                if isinstance(attn_processor, CustomIPAttentionProcessor):
+                    attn_processor.scale = scale
 
     # @torch.no_grad()
     # def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
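set_scale previously walked sd.unet.attn_processors, which matches how the UNet path registers its processors; for Pixart the processors were attached to the blocks directly, so that walk is now skipped. A hedged alternative, an assumption rather than what the commit does, would be to iterate the adapter's own module list so both paths are covered:

def set_scale(adapter_modules, scale: float):
    # assumed helper: adapter_modules holds every IP attention processor,
    # whether it was registered via set_attn_processor or assigned to
    # Pixart blocks by hand
    for proc in adapter_modules:
        if hasattr(proc, "scale"):
            proc.scale = scale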
@@ -554,7 +612,7 @@ class IPAdapter(torch.nn.Module):
             if self.clip_noise_zero:
                 tensors_0_1 = torch.rand_like(tensors_0_1).detach()
                 noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
-                                         dtype=get_torch_dtype(self.sd_ref().dtype))
+                                         dtype=get_torch_dtype(self.sd_ref().dtype))
                 tensors_0_1 = tensors_0_1 * noise_scale
             else:
                 tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
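This block builds the "unconditional" image input used when the image prompt is dropped out: either random noise with a per-sample random strength, or all zeros, depending on clip_noise_zero. A standalone sketch of the same idea; the tensor shapes are illustrative:

import torch

def unconditional_images(images: torch.Tensor, noise_instead_of_zero: bool) -> torch.Tensor:
    # images: (batch, channels, height, width) in the 0..1 range
    if noise_instead_of_zero:
        noise = torch.rand_like(images).detach()
        # one random strength per sample, broadcast over C/H/W
        noise_scale = torch.rand([images.shape[0], 1, 1, 1], device=images.device, dtype=images.dtype)
        return noise * noise_scale
    return torch.zeros_like(images).detach()

batch = torch.rand(4, 3, 224, 224)
print(unconditional_images(batch, True).shape)   # torch.Size([4, 3, 224, 224])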
@@ -675,7 +733,6 @@ class IPAdapter(torch.nn.Module):
             embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
         return embeddings
 
-
     def train(self: T, mode: bool = True) -> T:
         if self.config.train_image_encoder:
             self.image_encoder.train(mode)
@@ -721,18 +778,22 @@ class IPAdapter(torch.nn.Module):
                        raise ValueError(f"unknown shape: {current_shape}")
                except RuntimeError as e:
                    print(e)
-                   print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
+                   print(
+                       f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
 
                    if len(current_shape) == 1:
                        current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
                    elif len(current_shape) == 2:
-                       current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
+                       current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[
+                           :current_shape[0],
+                           :current_shape[1]]
                    elif len(current_shape) == 3:
-                       current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
+                       current_img_proj_state_dict[key][:current_shape[0], :current_shape[1],
+                           :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
                    elif len(current_shape) == 4:
                        current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2],
                            :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2],
-                           :current_shape[3]]
+                           :current_shape[3]]
                    else:
                        raise ValueError(f"unknown shape: {current_shape}")
                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
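When a checkpoint tensor and the current tensor disagree in shape, this code falls back to copying only the region both tensors share, slicing each dimension down. The per-rank elif chain can be written once for any rank; a hedged generalization of that idea, with an illustrative helper name:

import torch

def merge_overlap(dst: torch.Tensor, src: torch.Tensor) -> None:
    # copy src into dst over the region where both tensors overlap,
    # a rank-agnostic version of the rank-1..4 elif chain above
    if dst.dim() != src.dim():
        raise ValueError(f"rank mismatch: {dst.shape} vs {src.shape}")
    slices = tuple(slice(0, min(d, s)) for d, s in zip(dst.shape, src.shape))
    dst[slices] = src[slices]

dst = torch.zeros(4, 8)
src = torch.ones(6, 6)
merge_overlap(dst, src)     # copies the top-left 4x6 block
print(dst.sum().item())     # 24.0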
@@ -763,16 +824,24 @@ class IPAdapter(torch.nn.Module):
                        print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
                except RuntimeError as e:
                    print(e)
-                   print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
+                   print(
+                       f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
 
-                   if(len(current_shape) == 1):
+                   if (len(current_shape) == 1):
                        current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
-                   elif(len(current_shape) == 2):
-                       current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
-                   elif(len(current_shape) == 3):
-                       current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
-                   elif(len(current_shape) == 4):
-                       current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]]
+                   elif (len(current_shape) == 2):
+                       current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[
+                           :current_shape[
+                               0],
+                           :current_shape[
+                               1]]
+                   elif (len(current_shape) == 3):
+                       current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1],
+                           :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
+                   elif (len(current_shape) == 4):
+                       current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2],
+                           :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2],
+                           :current_shape[3]]
                    else:
                        raise ValueError(f"unknown shape: {current_shape}")
                    print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
@@ -781,7 +850,6 @@ class IPAdapter(torch.nn.Module):
                    current_ip_adapter_state_dict[key] = value
        self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)
 
-
    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        strict = False
        if 'ip_adapter' in state_dict:
@@ -801,7 +869,6 @@ class IPAdapter(torch.nn.Module):
            # we are loading pure clip weights.
            self.image_encoder.load_state_dict(state_dict, strict=strict)
 
-
    def enable_gradient_checkpointing(self):
        if hasattr(self.image_encoder, "enable_gradient_checkpointing"):
            self.image_encoder.enable_gradient_checkpointing()