diff --git a/toolkit/custom_adapter.py b/toolkit/custom_adapter.py
index 07a8e180..1d0e28f9 100644
--- a/toolkit/custom_adapter.py
+++ b/toolkit/custom_adapter.py
@@ -744,6 +744,7 @@ class CustomAdapter(torch.nn.Module):
             batch_size=1,
     ) -> PromptEmbeds:
         if self.adapter_type == 'ilora' or self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+            skip_unconditional = self.sd_ref().is_flux
             if tensors_0_1 is None:
                 tensors_0_1 = self.get_empty_clip_image(batch_size)
                 has_been_preprocessed = True
@@ -797,7 +798,7 @@ class CustomAdapter(torch.nn.Module):
 
             batch_size = clip_image.shape[0]
 
-            if self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter':
+            if (self.adapter_type == 'vision_direct' or self.adapter_type == 'te_augmenter') and not skip_unconditional:
                 # add an unconditional so we can save it
                 unconditional = self.get_empty_clip_image(batch_size, shape=clip_image.shape).to(
                     clip_image.device, dtype=clip_image.dtype
@@ -895,7 +896,10 @@ class CustomAdapter(torch.nn.Module):
 
             # save them to the conditional and unconditional
            try:
-                self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
+                if skip_unconditional:
+                    self.unconditional_embeds, self.conditional_embeds = None, clip_image_embeds
+                else:
+                    self.unconditional_embeds, self.conditional_embeds = clip_image_embeds.chunk(2, dim=0)
             except ValueError:
                 raise ValueError(f"could not split the clip image embeds into 2. Got shape: {clip_image_embeds.shape}")
 
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index e675e0b1..83407c9d 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -9,7 +9,7 @@
 from collections import OrderedDict
 from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
-from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor
+from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
 from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT
 
@@ -462,11 +462,23 @@ class VisionDirectAdapter(torch.nn.Module):
         self.sd_ref: weakref.ref = weakref.ref(sd)
         self.config: AdapterConfig = adapter.config
         self.vision_model_ref: weakref.ref = weakref.ref(vision_model)
+        self.resampler = None
+        is_pixtral = self.config.image_encoder_arch == "pixtral"
 
         if adapter.config.clip_layer == "image_embeds":
             self.token_size = vision_model.config.projection_dim
         else:
             self.token_size = vision_model.config.hidden_size
+
+        self.mid_size = self.token_size
+
+        # if pixtral, use cross attn dim for more sparse representation
+        if is_pixtral:
+            if is_flux:
+                hidden_size = 3072
+            else:
+                hidden_size = sd.unet.config['cross_attention_dim']
+            self.mid_size = hidden_size
 
         # init adapter modules
         attn_procs = {}
@@ -487,9 +499,10 @@ class VisionDirectAdapter(torch.nn.Module):
             for i, module in transformer.transformer_blocks.named_children():
                 attn_processor_keys.append(f"transformer_blocks.{i}.attn")
 
-            # single transformer blocks do not have cross attn, but we will do them anyway
-            for i, module in transformer.single_transformer_blocks.named_children():
-                attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
+            if not self.config.flux_only_double:
+                # single transformer blocks do not have cross attn, but we will do them anyway
+                for i, module in transformer.single_transformer_blocks.named_children():
+                    attn_processor_keys.append(f"single_transformer_blocks.{i}.attn")
         else:
             attn_processor_keys = list(sd.unet.attn_processors.keys())
 
@@ -532,27 +545,27 @@ class VisionDirectAdapter(torch.nn.Module):
                 to_v_adapter = unet_sd[layer_name + ".to_v.weight"]
 
                 # add zero padding to the adapter
-                if to_k_adapter.shape[1] < self.token_size:
+                if to_k_adapter.shape[1] < self.mid_size:
                     to_k_adapter = torch.cat([
                         to_k_adapter,
-                        torch.randn(to_k_adapter.shape[0], self.token_size - to_k_adapter.shape[1]).to(
+                        torch.randn(to_k_adapter.shape[0], self.mid_size - to_k_adapter.shape[1]).to(
                             to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
                     ],
                         dim=1
                     )
                     to_v_adapter = torch.cat([
                         to_v_adapter,
-                        torch.randn(to_v_adapter.shape[0], self.token_size - to_v_adapter.shape[1]).to(
+                        torch.randn(to_v_adapter.shape[0], self.mid_size - to_v_adapter.shape[1]).to(
                             to_k_adapter.device, dtype=to_k_adapter.dtype) * 0.01
                     ],
                         dim=1
                     )
-                elif to_k_adapter.shape[1] > self.token_size:
-                    to_k_adapter = to_k_adapter[:, :self.token_size]
-                    to_v_adapter = to_v_adapter[:, :self.token_size]
+                elif to_k_adapter.shape[1] > self.mid_size:
+                    to_k_adapter = to_k_adapter[:, :self.mid_size]
+                    to_v_adapter = to_v_adapter[:, :self.mid_size]
                     # if is_pixart:
-                    #     to_k_bias = to_k_bias[:self.token_size]
-                    #     to_v_bias = to_v_bias[:self.token_size]
+                    #     to_k_bias = to_k_bias[:self.mid_size]
+                    #     to_v_bias = to_v_bias[:self.mid_size]
                 else:
                     to_k_adapter = to_k_adapter
                     to_v_adapter = to_v_adapter
@@ -574,7 +587,7 @@ class VisionDirectAdapter(torch.nn.Module):
                         cross_attention_dim=cross_attention_dim,
                         scale=1.0,
                         adapter=self,
-                        adapter_hidden_size=self.token_size,
+                        adapter_hidden_size=self.mid_size,
                         has_bias=False,
                         block_idx=current_idx
                     )
@@ -584,7 +597,7 @@ class VisionDirectAdapter(torch.nn.Module):
                         cross_attention_dim=cross_attention_dim,
                         scale=1.0,
                         adapter=self,
-                        adapter_hidden_size=self.token_size,
+                        adapter_hidden_size=self.mid_size,
                         has_bias=False,
                     )
                 current_idx += 1
@@ -655,9 +668,15 @@ class VisionDirectAdapter(torch.nn.Module):
             self.resampler = MLPR(
                 in_dim=self.token_size,
                 in_channels=max_seq_len,
-                out_dim=self.token_size,
+                out_dim=self.mid_size,
                 out_channels=self.config.num_tokens,
             )
+
+        elif self.config.image_encoder_arch == "pixtral":
+            self.resampler = VisionLanguageAdapter(
+                in_dim=self.token_size,
+                out_dim=self.mid_size,
+            )
 
     def state_dict(self, destination=None, prefix='', keep_vars=False):
         if self.config.train_scaler:
@@ -678,7 +697,7 @@ class VisionDirectAdapter(torch.nn.Module):
         # todo remove this when we have a real solution
         if self.block_scaler is not None and self.block_scaler.dtype != torch.float32:
             self.block_scaler.data = self.block_scaler.data.to(torch.float32)
-        if self.config.num_tokens is not None:
+        if self.resampler is not None:
            input = self.resampler(input)
         return input
 