Bug fixes. Added IP adapter training for Pixart

Jaret Burkett
2024-02-17 10:06:57 -07:00
parent 93b52932c1
commit 2478554c95
4 changed files with 278 additions and 49 deletions


@@ -4,6 +4,7 @@ import torch
import sys
from PIL import Image
from diffusers import Transformer2DModel
from torch.nn import Parameter
from torch.nn.modules.module import T
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@@ -79,6 +80,10 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if is_active:
# since we are removing tokens, we need to adjust the sequence length
sequence_length = sequence_length - self.num_tokens
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
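The IP adapter appends self.num_tokens learned image tokens to the end of the encoder hidden states (they are concatenated onto the text embeds later in this file), so when the adapter is active the attention mask must be prepared against the text-only length. A minimal sketch of that bookkeeping, with purely illustrative tensor sizes:

import torch

num_tokens = 4  # hypothetical number of appended IP image tokens
encoder_hidden_states = torch.randn(2, 77 + num_tokens, 768)  # text tokens + IP tokens

# The IP tokens are split off before the text attention runs, so the mask
# has to be built for the remaining text tokens only.
sequence_length = encoder_hidden_states.shape[1] - num_tokens
text_states = encoder_hidden_states[:, :sequence_length, :]
ip_states = encoder_hidden_states[:, sequence_length:, :]
assert text_states.shape[1] == 77 and ip_states.shape[1] == num_tokens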
@@ -90,6 +95,9 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# will be None if the adapter is disabled
if not is_active:
ip_hidden_states = None
@@ -120,9 +128,13 @@ class CustomIPAttentionProcessor(IPAttnProcessor2_0):
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
try:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
except Exception as e:
print(e)
raise e
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
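The reshape after the attention call undoes the multi-head layout: scaled_dot_product_attention returns (batch, heads, seq_len, head_dim), which is transposed and flattened back to (batch, seq_len, heads * head_dim). A small self-contained sketch with made-up dimensions:

import torch
import torch.nn.functional as F

batch_size, heads, seq_len, head_dim = 2, 8, 16, 64  # illustrative sizes
query = key = value = torch.randn(batch_size, heads, seq_len, head_dim)

out = F.scaled_dot_product_attention(query, key, value)  # (B, heads, seq, head_dim)
out = out.transpose(1, 2).reshape(batch_size, -1, heads * head_dim)  # (B, seq, heads * head_dim)
assert out.shape == (batch_size, seq_len, heads * head_dim)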
@@ -235,7 +247,7 @@ class IPAdapter(torch.nn.Module):
print(f"could not load image processor from {adapter_config.image_encoder_path}")
self.clip_image_processor = ConvNextImageProcessor(
size=512,
image_mean=[0.485,0.456,0.406],
image_mean=[0.485, 0.456, 0.406],
image_std=[0.229, 0.224, 0.225],
)
self.image_encoder = ConvNextV2ForImageClassification.from_pretrained(
@@ -299,6 +311,7 @@ class IPAdapter(torch.nn.Module):
raise ValueError(f"unknown image processor size: {self.clip_image_processor.size}")
self.current_scale = 1.0
self.is_active = True
is_pixart = sd.is_pixart
if adapter_config.type == 'ip':
# ip-adapter
image_proj_model = ImageProjModel(
@@ -310,14 +323,22 @@ class IPAdapter(torch.nn.Module):
heads = 12 if not sd.is_xl else 20
dim = sd.unet.config['cross_attention_dim'] if not sd.is_xl else 1280
embedding_dim = self.image_encoder.config.hidden_size if not self.config.image_encoder_arch == "convnext" else \
self.image_encoder.config.hidden_sizes[-1]
self.image_encoder.config.hidden_sizes[-1]
image_encoder_state_dict = self.image_encoder.state_dict()
# max_seq_len = CLIP tokens + CLS token
max_seq_len = 257
if "vision_model.embeddings.position_embedding.weight" in image_encoder_state_dict:
# clip
max_seq_len = int(image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
max_seq_len = int(
image_encoder_state_dict["vision_model.embeddings.position_embedding.weight"].shape[0])
output_dim = sd.unet.config['cross_attention_dim']
if is_pixart:
heads = 20
dim = 4096
output_dim = 4096
# ip-adapter-plus
image_proj_model = Resampler(
@@ -328,7 +349,7 @@ class IPAdapter(torch.nn.Module):
num_queries=self.config.num_tokens if self.config.num_tokens > 0 else max_seq_len,
embedding_dim=embedding_dim,
max_seq_len=max_seq_len,
output_dim=sd.unet.config['cross_attention_dim'],
output_dim=output_dim,
ff_mult=4
)
elif adapter_config.type == 'ipz':
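For PixArt the Resampler targets the width of the text conditioning rather than a UNet cross-attention dim, hence heads = 20 and output_dim = dim = 4096, which lines up with the 4096-dim T5 text embeddings the image tokens get concatenated with. max_seq_len is read from the vision tower's position embedding table (patch tokens plus the CLS token). A toy illustration of that lookup; the 257 x 1024 shape is an assumption matching a CLIP ViT-L/14-style encoder, not something read from this repo:

import torch

# Hypothetical CLIP-style vision state dict: 257 positions (16 x 16 patches + CLS).
image_encoder_state_dict = {
    "vision_model.embeddings.position_embedding.weight": torch.zeros(257, 1024),
}

key = "vision_model.embeddings.position_embedding.weight"
max_seq_len = 257  # default: CLIP tokens + CLS token
if key in image_encoder_state_dict:
    max_seq_len = int(image_encoder_state_dict[key].shape[0])
assert max_seq_len == 257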
@@ -373,8 +394,21 @@ class IPAdapter(torch.nn.Module):
# init adapter modules
attn_procs = {}
unet_sd = sd.unet.state_dict()
for name in sd.unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else sd.unet.config['cross_attention_dim']
attn_processor_keys = []
if is_pixart:
transformer: Transformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
attn_processor_keys.append(f"transformer_blocks.{i}.attn1")
# cross attention
attn_processor_keys.append(f"transformer_blocks.{i}.attn2")
else:
attn_processor_keys = list(sd.unet.attn_processors.keys())
for name in attn_processor_keys:
cross_attention_dim = None if name.endswith("attn1.processor") or name.endswith("attn1") else \
sd.unet.config['cross_attention_dim']
if name.startswith("mid_block"):
hidden_size = sd.unet.config['block_out_channels'][-1]
elif name.startswith("up_blocks"):
@@ -383,6 +417,8 @@ class IPAdapter(torch.nn.Module):
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = sd.unet.config['block_out_channels'][block_id]
elif name.startswith("transformer"):
hidden_size = sd.unet.config['cross_attention_dim']
else:
# upstream didn't have this branch, but without it hidden_size would be undefined below
raise ValueError(f"unknown attn processor name: {name}")
@@ -402,14 +438,35 @@ class IPAdapter(torch.nn.Module):
num_tokens=self.config.num_tokens,
adapter=self
)
if self.sd_ref().is_pixart:
# pixart is much more sensitive, so start the IP projections near zero
weights = {
"to_k_ip.weight": weights["to_k_ip.weight"] * 0.01,
"to_v_ip.weight": weights["to_v_ip.weight"] * 0.01,
}
attn_procs[name].load_state_dict(weights)
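Scaling to_k_ip.weight and to_v_ip.weight by 0.01 makes the image-conditioned branch start close to a no-op, the stated workaround for PixArt being more sensitive than the UNet models. Since a linear layer's output scales linearly with its weight, the effect can be checked in isolation; the 1280 -> 4096 shapes below are illustrative:

import torch

proj = torch.nn.Linear(1280, 4096, bias=False)  # stand-in for to_k_ip / to_v_ip
x = torch.randn(2, 16, 1280)
y_full = proj(x)

with torch.no_grad():
    proj.weight.mul_(0.01)  # same 0.01 factor applied to the initial weights above

y_scaled = proj(x)
assert torch.allclose(y_scaled, y_full * 0.01, atol=1e-5)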
sd.unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
if self.sd_ref().is_pixart:
# we have to set them ourselves
transformer: Transformer2DModel = sd.unet
for i, module in transformer.transformer_blocks.named_children():
module.attn1.processor = attn_procs[f"transformer_blocks.{i}.attn1"]
module.attn2.processor = attn_procs[f"transformer_blocks.{i}.attn2"]
self.adapter_modules = torch.nn.ModuleList(
[
transformer.transformer_blocks[i].attn1.processor for i in
range(len(transformer.transformer_blocks))
] + [
transformer.transformer_blocks[i].attn2.processor for i in
range(len(transformer.transformer_blocks))
])
else:
sd.unet.set_attn_processor(attn_procs)
self.adapter_modules = torch.nn.ModuleList(sd.unet.attn_processors.values())
sd.adapter = self
self.unet_ref: weakref.ref = weakref.ref(sd.unet)
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
# load the weights if we have some
if self.config.name_or_path:
loaded_state_dict = load_ip_adapter_model(
@@ -473,9 +530,10 @@ class IPAdapter(torch.nn.Module):
def set_scale(self, scale):
self.current_scale = scale
for attn_processor in self.sd_ref().unet.attn_processors.values():
if isinstance(attn_processor, CustomIPAttentionProcessor):
attn_processor.scale = scale
if not self.sd_ref().is_pixart:
for attn_processor in self.sd_ref().unet.attn_processors.values():
if isinstance(attn_processor, CustomIPAttentionProcessor):
attn_processor.scale = scale
# @torch.no_grad()
# def get_clip_image_embeds_from_pil(self, pil_image: Union[Image.Image, List[Image.Image]],
@@ -554,7 +612,7 @@ class IPAdapter(torch.nn.Module):
if self.clip_noise_zero:
tensors_0_1 = torch.rand_like(tensors_0_1).detach()
noise_scale = torch.rand([tensors_0_1.shape[0], 1, 1, 1], device=self.device,
dtype=get_torch_dtype(self.sd_ref().dtype))
dtype=get_torch_dtype(self.sd_ref().dtype))
tensors_0_1 = tensors_0_1 * noise_scale
else:
tensors_0_1 = torch.zeros_like(tensors_0_1).detach()
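When clip_noise_zero is set, the dropped image conditioning is replaced with random noise whose magnitude is drawn per sample: the [B, 1, 1, 1] noise_scale broadcasts across channels and spatial dims, so each image in the batch gets its own strength; otherwise the conditioning collapses to zeros. A minimal shape check with arbitrary sizes:

import torch

tensors_0_1 = torch.rand(4, 3, 224, 224)                  # stand-in batch of clip inputs
noise = torch.rand_like(tensors_0_1).detach()
noise_scale = torch.rand(tensors_0_1.shape[0], 1, 1, 1)   # one factor per sample
tensors_0_1 = noise * noise_scale                          # every element of sample b shares noise_scale[b]
assert tensors_0_1.shape == (4, 3, 224, 224)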
@@ -675,7 +733,6 @@ class IPAdapter(torch.nn.Module):
embeddings.text_embeds = torch.cat([embeddings.text_embeds, image_prompt_embeds], dim=1)
return embeddings
def train(self: T, mode: bool = True) -> T:
if self.config.train_image_encoder:
self.image_encoder.train(mode)
@@ -721,18 +778,22 @@ class IPAdapter(torch.nn.Module):
raise ValueError(f"unknown shape: {current_shape}")
except RuntimeError as e:
print(e)
print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
print(
f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
if len(current_shape) == 1:
current_img_proj_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
elif len(current_shape) == 2:
current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
current_img_proj_state_dict[key][:current_shape[0], :current_shape[1]] = value[
:current_shape[0],
:current_shape[1]]
elif len(current_shape) == 3:
current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
current_img_proj_state_dict[key][:current_shape[0], :current_shape[1],
:current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
elif len(current_shape) == 4:
current_img_proj_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2],
:current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2],
:current_shape[3]]
:current_shape[3]]
else:
raise ValueError(f"unknown shape: {current_shape}")
print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
@@ -763,16 +824,24 @@ class IPAdapter(torch.nn.Module):
print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
except RuntimeError as e:
print(e)
print(f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
print(
f"could not merge in {key}: {list(current_shape)} <<< {list(new_shape)}. Trying other way")
if(len(current_shape) == 1):
if (len(current_shape) == 1):
current_ip_adapter_state_dict[key][:current_shape[0]] = value[:current_shape[0]]
elif(len(current_shape) == 2):
current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = value[:current_shape[0], :current_shape[1]]
elif(len(current_shape) == 3):
current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
elif(len(current_shape) == 4):
current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2], :current_shape[3]]
elif (len(current_shape) == 2):
current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1]] = \
value[:current_shape[0], :current_shape[1]]
elif (len(current_shape) == 3):
current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1],
:current_shape[2]] = value[:current_shape[0], :current_shape[1], :current_shape[2]]
elif (len(current_shape) == 4):
current_ip_adapter_state_dict[key][:current_shape[0], :current_shape[1], :current_shape[2],
:current_shape[3]] = value[:current_shape[0], :current_shape[1], :current_shape[2],
:current_shape[3]]
else:
raise ValueError(f"unknown shape: {current_shape}")
print(f"Force merged in {key}: {list(current_shape)} <<< {list(new_shape)}")
@@ -781,7 +850,6 @@ class IPAdapter(torch.nn.Module):
current_ip_adapter_state_dict[key] = value
self.adapter_modules.load_state_dict(current_ip_adapter_state_dict)
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
strict = False
if 'ip_adapter' in state_dict:
@@ -801,7 +869,6 @@ class IPAdapter(torch.nn.Module):
# we are loading pure clip weights.
self.image_encoder.load_state_dict(state_dict, strict=strict)
def enable_gradient_checkpointing(self):
if hasattr(self.image_encoder, "enable_gradient_checkpointing"):
self.image_encoder.enable_gradient_checkpointing()