Added some more useful error handling and logging

Jaret Burkett
2025-04-07 08:01:37 -06:00
parent 7c21eac1b3
commit 6c8b5ab606
2 changed files with 53 additions and 26 deletions
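The diff below wraps the adapter attention math in a try/except that prints the shape of every participating tensor, then the traceback, so a shape mismatch is diagnosable from console output. A minimal sketch of that shape-logging pattern, assuming a hypothetical helper debug_matmul (illustrative only, not code from this commit):

import traceback

import torch

def debug_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Run a shape-sensitive op; on failure, log every operand's shape
    # before printing the traceback.
    try:
        return a @ b
    except Exception:
        print("Error in debug_matmul")
        print(f"a: {a.shape}")
        print(f"b: {b.shape}")
        traceback.print_exc()
        raise  # note: the commit's handler logs and falls through instead of re-raising

For example, debug_matmul(torch.randn(2, 3), torch.randn(4, 5)) prints both operand shapes and the traceback before raising.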


@@ -11,7 +11,7 @@ from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
 from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
 from transformers import SiglipImageProcessor, SiglipVisionModel
+import traceback
 from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT
 sys.path.append(REPOS_ROOT)
@@ -291,36 +291,52 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
         # only use one TE or the other. If our adapter is active only use ours
         if self.is_active and self.conditional_embeds is not None:
-            adapter_hidden_states = self.conditional_embeds
-            if adapter_hidden_states.shape[0] < batch_size:
-                adapter_hidden_states = torch.cat([
-                    self.unconditional_embeds,
-                    adapter_hidden_states
-                ], dim=0)
-            # if it is image embeds, we need to add a 1 dim at inx 1
-            if len(adapter_hidden_states.shape) == 2:
-                adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
-            # conditional_batch_size = adapter_hidden_states.shape[0]
-            # conditional_query = query
-            # for ip-adapter
-            vd_key = self.to_k_adapter(adapter_hidden_states)
-            vd_value = self.to_v_adapter(adapter_hidden_states)
-            vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            vd_hidden_states = F.scaled_dot_product_attention(
-                query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
-            vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            vd_hidden_states = vd_hidden_states.to(query.dtype)
-            hidden_states = hidden_states + self.scale * vd_hidden_states
+            try:
+                adapter_hidden_states = self.conditional_embeds
+                if adapter_hidden_states.shape[0] == batch_size // 2:
+                    adapter_hidden_states = torch.cat([
+                        self.unconditional_embeds,
+                        adapter_hidden_states
+                    ], dim=0)
+                # if it is image embeds, we need to add a 1 dim at inx 1
+                if len(adapter_hidden_states.shape) == 2:
+                    adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
+                # conditional_batch_size = adapter_hidden_states.shape[0]
+                # conditional_query = query
+                # for ip-adapter
+                vd_key = self.to_k_adapter(adapter_hidden_states)
+                vd_value = self.to_v_adapter(adapter_hidden_states)
+                vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                vd_hidden_states = F.scaled_dot_product_attention(
+                    query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
+                vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                vd_hidden_states = vd_hidden_states.to(query.dtype)
+                hidden_states = hidden_states + self.scale * vd_hidden_states
+            except Exception as e:
+                print("Error in VisionDirectAdapterAttnProcessor")
+                # print shapes of all tensors
+                print(f"hidden_states: {hidden_states.shape}")
+                print(f"adapter_hidden_states: {adapter_hidden_states.shape}")
+                print(f"vd_key: {vd_key.shape}")
+                print(f"vd_value: {vd_value.shape}")
+                print(f"vd_hidden_states: {vd_hidden_states.shape}")
+                print(f"query: {query.shape}")
+                print(f"key: {key.shape}")
+                print(f"value: {value.shape}")
+                print(f"inner_dim: {inner_dim}")
+                print(f"head_dim: {head_dim}")
+                print(f"batch_size: {batch_size}")
+                traceback.print_exc()
         # linear proj
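The substantive logic change, beyond the try/except, is the batch check: the old code prepended the unconditional embeds whenever the adapter batch was smaller than the attention batch, while the new code does so only when the conditional embeds cover exactly half the batch, presumably a classifier-free-guidance batch laid out as [unconditional, conditional]. A minimal sketch of that rule, using the hypothetical helper name expand_adapter_embeds:

import torch

def expand_adapter_embeds(conditional_embeds, unconditional_embeds, batch_size):
    # Prepend unconditional embeds only when the conditional embeds
    # cover exactly half the batch (classifier-free guidance).
    adapter_hidden_states = conditional_embeds
    if adapter_hidden_states.shape[0] == batch_size // 2:
        adapter_hidden_states = torch.cat([
            unconditional_embeds,
            adapter_hidden_states
        ], dim=0)
    # image embeds arrive as (batch, dim); give them a sequence axis of length 1
    if adapter_hidden_states.dim() == 2:
        adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
    return adapter_hidden_states

For a CFG batch of 4 (2 unconditional + 2 conditional rows), expand_adapter_embeds(torch.randn(2, 768), torch.randn(2, 768), batch_size=4) returns a (4, 1, 768) tensor; a full batch of conditional embeds passes through unchanged apart from the added sequence axis.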