mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Added some more useful error handling and logging
This commit is contained in:
@@ -11,7 +11,7 @@ from diffusers import Transformer2DModel, FluxTransformer2DModel
|
||||
from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
|
||||
from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
|
||||
from transformers import SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
import traceback
|
||||
from toolkit.config_modules import AdapterConfig
|
||||
from toolkit.paths import REPOS_ROOT
|
||||
sys.path.append(REPOS_ROOT)
|
||||
@@ -291,36 +291,52 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
|
||||
|
||||
# only use one TE or the other. If our adapter is active only use ours
|
||||
if self.is_active and self.conditional_embeds is not None:
|
||||
try:
|
||||
|
||||
adapter_hidden_states = self.conditional_embeds
|
||||
if adapter_hidden_states.shape[0] < batch_size:
|
||||
adapter_hidden_states = torch.cat([
|
||||
self.unconditional_embeds,
|
||||
adapter_hidden_states
|
||||
], dim=0)
|
||||
# if it is image embeds, we need to add a 1 dim at inx 1
|
||||
if len(adapter_hidden_states.shape) == 2:
|
||||
adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
|
||||
# conditional_batch_size = adapter_hidden_states.shape[0]
|
||||
# conditional_query = query
|
||||
adapter_hidden_states = self.conditional_embeds
|
||||
if adapter_hidden_states.shape[0] == batch_size // 2:
|
||||
adapter_hidden_states = torch.cat([
|
||||
self.unconditional_embeds,
|
||||
adapter_hidden_states
|
||||
], dim=0)
|
||||
# if it is image embeds, we need to add a 1 dim at inx 1
|
||||
if len(adapter_hidden_states.shape) == 2:
|
||||
adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
|
||||
# conditional_batch_size = adapter_hidden_states.shape[0]
|
||||
# conditional_query = query
|
||||
|
||||
# for ip-adapter
|
||||
vd_key = self.to_k_adapter(adapter_hidden_states)
|
||||
vd_value = self.to_v_adapter(adapter_hidden_states)
|
||||
# for ip-adapter
|
||||
vd_key = self.to_k_adapter(adapter_hidden_states)
|
||||
vd_value = self.to_v_adapter(adapter_hidden_states)
|
||||
|
||||
vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
vd_hidden_states = F.scaled_dot_product_attention(
|
||||
query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
vd_hidden_states = F.scaled_dot_product_attention(
|
||||
query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
vd_hidden_states = vd_hidden_states.to(query.dtype)
|
||||
vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
vd_hidden_states = vd_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + self.scale * vd_hidden_states
|
||||
hidden_states = hidden_states + self.scale * vd_hidden_states
|
||||
except Exception as e:
|
||||
print("Error in VisionDirectAdapterAttnProcessor")
|
||||
# print shapes of all tensors
|
||||
print(f"hidden_states: {hidden_states.shape}")
|
||||
print(f"adapter_hidden_states: {adapter_hidden_states.shape}")
|
||||
print(f"vd_key: {vd_key.shape}")
|
||||
print(f"vd_value: {vd_value.shape}")
|
||||
print(f"vd_hidden_states: {vd_hidden_states.shape}")
|
||||
print(f"query: {query.shape}")
|
||||
print(f"key: {key.shape}")
|
||||
print(f"value: {value.shape}")
|
||||
print(f"inner_dim: {inner_dim}")
|
||||
print(f"head_dim: {head_dim}")
|
||||
print(f"batch_size: {batch_size}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
# linear proj
|
||||
|
||||
Reference in New Issue
Block a user