Added Vision Language Adapter usage for pixtral vision direct (vd) adapter

This commit is contained in:
Jaret Burkett
2024-09-29 19:39:56 -06:00
parent b4f64de4c2
commit f05224970f
2 changed files with 41 additions and 18 deletions

View File

@@ -744,6 +744,7 @@ class CustomAdapter(torch.nn.Module):
batch_size=1,
) -> PromptEmbeds:
if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
skip_unconditional = self.sd_ref().is_flux
if tensors_0_1 is None:
tensors_0_1 = self.get_empty_clip_image(batch_size)
has_been_preprocessed = True
@@ -797,7 +798,7 @@ class CustomAdapter(torch.nn.Module):
batch_size = clip_image.shape[0]
if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
if (self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter') and not skip_unconditional:
# add an unconditional so we can save it
unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
clip_image.device, dtype=clip_image.dtype
@@ -895,7 +896,10 @@ class CustomAdapter(torch.nn.Module):
# save them to the conditional and unconditional
try:
self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
if skip_unconditional:
self.unconditional_embeds, self.conditional_embeds = None, clip_image_embeds
else:
self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
except ValueError:
raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")

View File

@@ -9,7 +9,7 @@ from collections import OrderedDict
from diffusers import Transformer2DModel, FluxTransformer2DModel
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
from toolkit.config_modules import AdapterConfig
from toolkit.paths import REPOS_ROOT
@@ -462,11 +462,23 @@ class VisionDirectAdapter(torch.nn.Module):
self.sd_ref: weakref.ref = weakref.ref(sd)
self.config: AdapterConfig = adapter.config
self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
self.resampler = None
is_pixtral = self.config.image_encoder_arch == "pixtral"
if adapter.config.clip_layer == "image_embeds":
self.token_size = vision_model.config.projection_dim
else:
self.token_size = vision_model.config.hidden_size
self.mid_size = self.token_size
# if pixtral, use cross attn dim for more sparse representation
if is_pixtral:
if is_flux:
hidden_size = 3072
else:
hidden_size = sd.unet.config['cross_attention_dim']
self.mid_size = hidden_size
# init adapter modules
attn_procs = {}
@@ -487,9 +499,10 @@ class VisionDirectAdapter(torch.nn.Module):
for i, module in transformer.transformer_blocks.named_children():
attn_processor_keys.append(f"transformer_blocks.{i}.attn")
# single transformer blocks do not have cross attn, but we will do them anyway
for i, module in transformer.single_transformer_blocks.named_children():
attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
if not self.config.flux_only_double:
# single transformer blocks do not have cross attn, but we will do them anyway
for i, module in transformer.single_transformer_blocks.named_children():
attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
else:
attn_processor_keys = list(sd.unet.attn_processors.keys())
@@ -532,27 +545,27 @@ class VisionDirectAdapter(torch.nn.Module):
to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
# add zero padding to the adapter
if to_k_adapter.shape[1] < self.token_size:
if to_k_adapter.shape[1] < self.mid_size:
to_k_adapter = torch.cat([
to_k_adapter,
torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to(
torch.randn(to_k_adapter.shape[0], self.mid_size - to_k_adapter.shape[1]).to(
to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
],
dim=1
)
to_v_adapter = torch.cat([
to_v_adapter,
torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to(
torch.randn(to_v_adapter.shape[0], self.mid_size - to_v_adapter.shape[1]).to(
to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
],
dim=1
)
elif to_k_adapter.shape[1] > self.token_size:
to_k_adapter = to_k_adapter[:, :self.token_size]
to_v_adapter = to_v_adapter[:, :self.token_size]
elif to_k_adapter.shape[1] > self.mid_size:
to_k_adapter = to_k_adapter[:, :self.mid_size]
to_v_adapter = to_v_adapter[:, :self.mid_size]
# if is_pixart:
# to_k_bias = to_k_bias[:self.token_size]
# to_v_bias = to_v_bias[:self.token_size]
# to_k_bias = to_k_bias[:self.mid_size]
# to_v_bias = to_v_bias[:self.mid_size]
else:
to_k_adapter = to_k_adapter
to_v_adapter = to_v_adapter
@@ -574,7 +587,7 @@ class VisionDirectAdapter(torch.nn.Module):
cross_attention_dim=cross_attention_dim,
scale=1.0,
adapter=self,
adapter_hidden_size=self.token_size,
adapter_hidden_size=self.mid_size,
has_bias=False,
block_idx=current_idx
)
@@ -584,7 +597,7 @@ class VisionDirectAdapter(torch.nn.Module):
cross_attention_dim=cross_attention_dim,
scale=1.0,
adapter=self,
adapter_hidden_size=self.token_size,
adapter_hidden_size=self.mid_size,
has_bias=False,
)
current_idx += 1
@@ -655,9 +668,15 @@ class VisionDirectAdapter(torch.nn.Module):
self.resampler = MLPR(
in_dim=self.token_size,
in_channels=max_seq_len,
out_dim=self.token_size,
out_dim=self.mid_size,
out_channels=self.config.num_tokens,
)
elif self.config.image_encoder_arch == "pixtral":
self.resampler = VisionLanguageAdapter(
in_dim=self.token_size,
out_dim=self.mid_size,
)
def state_dict(self, destination=None, prefix='', keep_vars=False):
if self.config.train_scaler:
@@ -678,7 +697,7 @@ class VisionDirectAdapter(torch.nn.Module):
# todo remove this when we have a real solution
if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
self.block_scaler.data = self.block_scaler.data.to(torch.float32)
if self.config.num_tokens is not None:
if self.resampler is not None:
input = self.resampler(input)
return input