Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-02-22 21:33:59 +00:00)
Added VisionLanguageAdapter usage for the pixtral vision_direct (vd) adapter
@@ -744,6 +744,7 @@ class CustomAdapter(torch.nn.Module):
             batch_size=1,
     ) -> PromptEmbeds:
         if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+            skip_unconditional = self.sd_ref().is_flux
             if tensors_0_1 is None:
                 tensors_0_1 = self.get_empty_clip_image(batch_size)
                 has_been_preprocessed = True

@@ -797,7 +798,7 @@ class CustomAdapter(torch.nn.Module):

             batch_size = clip_image.shape[0]
-            if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+            if (self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter') and not skip_unconditional:
                 # add an unconditional so we can save it
                 unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
                     clip_image.device, dtype=clip_image.dtype

@@ -895,7 +896,10 @@ class CustomAdapter(torch.nn.Module):

                 # save them to the conditional and unconditional
                 try:
-                    self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+                    if skip_unconditional:
+                        self.unconditional_embeds, self.conditional_embeds = None, clip_image_embeds
+                    else:
+                        self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
                 except ValueError:
                     raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")

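The change above means the Flux path never builds a paired [unconditional, conditional] image batch, so the embeds can no longer be chunked in two. A minimal, standalone sketch of the split logic (illustrative names, not the repo's code):

    import torch

    def split_embeds(clip_image_embeds: torch.Tensor, skip_unconditional: bool):
        if skip_unconditional:
            # no unconditional branch was computed; the whole batch is conditional
            return None, clip_image_embeds
        # otherwise the batch was stacked as [unconditional, conditional]
        unconditional, conditional = clip_image_embeds.chunk(2, dim=0)
        return unconditional, conditional

    cond_only = torch.randn(2, 4, 8)   # skip_unconditional path (Flux in this commit)
    paired = torch.randn(4, 4, 8)      # [uncond, cond] stacked along the batch dim
    print(split_embeds(cond_only, True)[1].shape)   # torch.Size([2, 4, 8])
    print(split_embeds(paired, False)[0].shape)     # torch.Size([2, 4, 8])
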
@@ -9,7 +9,7 @@ from collections import OrderedDict

from diffusers import Transformer2DModel, FluxTransformer2DModel
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
-from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor
+from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter

from toolkit.config_modules import AdapterConfig
from toolkit.paths import REPOS_ROOT

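VisionLanguageAdapter is defined in toolkit/models/pixtral_vision.py, which is not shown in this diff. As a rough sketch only: Pixtral-style vision-language adapters are typically a two-layer GELU MLP projecting vision tokens from in_dim to out_dim, matching the constructor call later in this commit; the toolkit's actual implementation may differ.

    import torch
    import torch.nn as nn

    class VisionLanguageAdapterSketch(nn.Module):
        # illustrative stand-in, not toolkit.models.pixtral_vision.VisionLanguageAdapter
        def __init__(self, in_dim: int, out_dim: int):
            super().__init__()
            self.w_in = nn.Linear(in_dim, out_dim)
            self.gelu = nn.GELU()
            self.w_out = nn.Linear(out_dim, out_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: [batch, vision_tokens, in_dim] -> [batch, vision_tokens, out_dim]
            return self.w_out(self.gelu(self.w_in(x)))

    tokens = torch.randn(1, 16, 1024)
    print(VisionLanguageAdapterSketch(1024, 3072)(tokens).shape)  # torch.Size([1, 16, 3072])
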
@@ -462,11 +462,23 @@ class VisionDirectAdapter(torch.nn.Module):
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.config: AdapterConfig = adapter.config
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
+        self.resampler = None
+        is_pixtral = self.config.image_encoder_arch == "pixtral"

         if adapter.config.clip_layer == "image_embeds":
             self.token_size = vision_model.config.projection_dim
         else:
             self.token_size = vision_model.config.hidden_size

+        self.mid_size = self.token_size
+
+        # if pixtral, use cross attn dim for more sparse representation
+        if is_pixtral:
+            if is_flux:
+                hidden_size = 3072
+            else:
+                hidden_size = sd.unet.config['cross_attention_dim']
+            self.mid_size = hidden_size
+
         # init adapter modules
         attn_procs = {}

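In effect, mid_size stays equal to token_size except for pixtral, where it becomes the transformer's attention width. A small illustrative helper (names are hypothetical) showing how the value resolves:

    def resolve_mid_size(token_size, is_pixtral, is_flux, cross_attention_dim=None):
        # non-pixtral encoders keep the vision hidden size as the adapter width
        if not is_pixtral:
            return token_size
        if is_flux:
            return 3072  # Flux transformer inner dim, hard-coded as in the diff above
        return cross_attention_dim  # e.g. sd.unet.config['cross_attention_dim']

    print(resolve_mid_size(1024, is_pixtral=True, is_flux=True))                             # 3072
    print(resolve_mid_size(1024, is_pixtral=True, is_flux=False, cross_attention_dim=2048))  # 2048
    print(resolve_mid_size(1280, is_pixtral=False, is_flux=False))                           # 1280
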
@@ -487,9 +499,10 @@ class VisionDirectAdapter(torch.nn.Module):
             for i, module in transformer.transformer_blocks.named_children():
                 attn_processor_keys.append(f"transformer_blocks.{i}.attn")

-            # single transformer blocks do not have cross attn, but we will do them anyway
-            for i, module in transformer.single_transformer_blocks.named_children():
-                attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
+            if not self.config.flux_only_double:
+                # single transformer blocks do not have cross attn, but we will do them anyway
+                for i, module in transformer.single_transformer_blocks.named_children():
+                    attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())

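The new flux_only_double flag lets the adapter attach only to the double-stream blocks and skip the single-stream ones. An illustrative sketch of the key collection, using stand-in block counts rather than a real FluxTransformer2DModel:

    def collect_attn_keys(num_double, num_single, flux_only_double):
        keys = [f"transformer_blocks.{i}.attn" for i in range(num_double)]
        if not flux_only_double:
            # single-stream blocks have no cross attention, but the adapter can hook them anyway
            keys += [f"single_transformer_blocks.{i}.attn" for i in range(num_single)]
        return keys

    print(len(collect_attn_keys(19, 38, flux_only_double=True)))   # 19
    print(len(collect_attn_keys(19, 38, flux_only_double=False)))  # 57
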
@@ -532,27 +545,27 @@ class VisionDirectAdapter(torch.nn.Module):
                 to_v_adapter = unet_sd[layer_name + ".to_v.weight"]

                 # add zero padding to the adapter
-                if to_k_adapter.shape[1] < self.token_size:
+                if to_k_adapter.shape[1] < self.mid_size:
                     to_k_adapter = torch.cat([
                         to_k_adapter,
-                        torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to(
+                        torch.randn(to_k_adapter.shape[0], self.mid_size - to_k_adapter.shape[1]).to(
                             to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
                     ],
                         dim=1
                     )
                     to_v_adapter = torch.cat([
                         to_v_adapter,
-                        torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to(
+                        torch.randn(to_v_adapter.shape[0], self.mid_size - to_v_adapter.shape[1]).to(
                             to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
                     ],
                         dim=1
                     )
-                elif to_k_adapter.shape[1] > self.token_size:
-                    to_k_adapter = to_k_adapter[:, :self.token_size]
-                    to_v_adapter = to_v_adapter[:, :self.token_size]
+                elif to_k_adapter.shape[1] > self.mid_size:
+                    to_k_adapter = to_k_adapter[:, :self.mid_size]
+                    to_v_adapter = to_v_adapter[:, :self.mid_size]
                     # if is_pixart:
-                    #     to_k_bias = to_k_bias[:self.token_size]
-                    #     to_v_bias = to_v_bias[:self.token_size]
+                    #     to_k_bias = to_k_bias[:self.mid_size]
+                    #     to_v_bias = to_v_bias[:self.mid_size]
                 else:
                     to_k_adapter = to_k_adapter
                     to_v_adapter = to_v_adapter

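The base-model to_k/to_v weights are now resized along their input dimension to mid_size instead of token_size, padding with small random values or truncating. A self-contained sketch of that resize (helper name is hypothetical):

    import torch

    def resize_kv_weight(weight, mid_size):
        # weight: [out_features, in_features] -> [out_features, mid_size]
        in_features = weight.shape[1]
        if in_features < mid_size:
            pad = torch.randn(weight.shape[0], mid_size - in_features,
                              device=weight.device, dtype=weight.dtype) * 0.01
            return torch.cat([weight, pad], dim=1)
        if in_features > mid_size:
            return weight[:, :mid_size]
        return weight

    w = torch.randn(320, 768)
    print(resize_kv_weight(w, 3072).shape)  # torch.Size([320, 3072])
    print(resize_kv_weight(w, 512).shape)   # torch.Size([320, 512])
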
@@ -574,7 +587,7 @@ class VisionDirectAdapter(torch.nn.Module):
                     cross_attention_dim=cross_attention_dim,
                     scale=1.0,
                     adapter=self,
-                    adapter_hidden_size=self.token_size,
+                    adapter_hidden_size=self.mid_size,
                     has_bias=False,
                     block_idx=current_idx
                 )

@@ -584,7 +597,7 @@ class VisionDirectAdapter(torch.nn.Module):
                     cross_attention_dim=cross_attention_dim,
                     scale=1.0,
                     adapter=self,
-                    adapter_hidden_size=self.token_size,
+                    adapter_hidden_size=self.mid_size,
                     has_bias=False,
                 )
             current_idx += 1

@@ -655,9 +668,15 @@ class VisionDirectAdapter(torch.nn.Module):
             self.resampler = MLPR(
                 in_dim=self.token_size,
                 in_channels=max_seq_len,
-                out_dim=self.token_size,
+                out_dim=self.mid_size,
                 out_channels=self.config.num_tokens,
             )

+        elif self.config.image_encoder_arch == "pixtral":
+            self.resampler = VisionLanguageAdapter(
+                in_dim=self.token_size,
+                out_dim=self.mid_size,
+            )
+
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         if self.config.train_scaler:

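With this branch, the pixtral path reuses the resampler slot to hold a VisionLanguageAdapter that projects encoder tokens from token_size to mid_size. A hedged usage sketch with a stand-in module (the real class comes from toolkit.models.pixtral_vision and may differ):

    import torch
    import torch.nn as nn

    token_size, mid_size = 1024, 3072        # e.g. pixtral hidden size -> Flux inner dim
    resampler = nn.Sequential(               # stand-in for VisionLanguageAdapter(in_dim, out_dim)
        nn.Linear(token_size, mid_size),
        nn.GELU(),
        nn.Linear(mid_size, mid_size),
    )
    vision_hidden_states = torch.randn(2, 256, token_size)  # [batch, vision tokens, token_size]
    adapter_tokens = resampler(vision_hidden_states)         # [batch, vision tokens, mid_size]
    print(adapter_tokens.shape)                               # torch.Size([2, 256, 3072])
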
@@ -678,7 +697,7 @@ class VisionDirectAdapter(torch.nn.Module):
         # todo remove this when we have a real solution
         if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
             self.block_scaler.data = self.block_scaler.data.to(torch.float32)
-        if self.config.num_tokens is not None:
+        if self.resampler is not None:
            input = self.resampler(input)
         return input

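Gating on self.resampler instead of config.num_tokens ensures the pixtral VisionLanguageAdapter, which is built without num_tokens, still runs in the forward pass. A tiny sketch of the guard (illustrative class, not the toolkit's):

    import torch
    import torch.nn as nn

    class ForwardGuardSketch(nn.Module):
        def __init__(self, resampler=None):
            super().__init__()
            self.resampler = resampler

        def forward(self, x):
            if self.resampler is not None:   # was: if self.config.num_tokens is not None
                x = self.resampler(x)
            return x

    x = torch.randn(1, 8, 16)
    print(ForwardGuardSketch(None)(x).shape)               # unchanged: torch.Size([1, 8, 16])
    print(ForwardGuardSketch(nn.Linear(16, 32))(x).shape)  # projected: torch.Size([1, 8, 32])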