diff --git a/jobs/process/BaseSDTrainProcess.py b/jobs/process/BaseSDTrainProcess.py
index e37edd3a..3c7f57ed 100644
--- a/jobs/process/BaseSDTrainProcess.py
+++ b/jobs/process/BaseSDTrainProcess.py
@@ -7,6 +7,7 @@ import shutil
 from collections import OrderedDict
 import os
 import re
+import traceback
 from typing import Union, List, Optional
 
 import numpy as np
@@ -2008,7 +2009,17 @@ class BaseSDTrainProcess(BaseTrainProcess):
             # flush()
             ### HOOK ###
             with self.accelerator.accumulate(self.modules_being_trained):
-                loss_dict = self.hook_train_loop(batch_list)
+                try:
+                    loss_dict = self.hook_train_loop(batch_list)
+                except Exception as e:
+                    traceback.print_exc()
+                    # print batch info
+                    print("Batch Items:")
+                    for batch in batch_list:
+                        for item in batch.file_items:
+                            print(f" - {item.path}")
+                    raise e
+
             self.timer.stop('train_loop')
             if not did_first_flush:
                 flush()
diff --git a/toolkit/models/vd_adapter.py b/toolkit/models/vd_adapter.py
index ea3f9bc7..52c38cec 100644
--- a/toolkit/models/vd_adapter.py
+++ b/toolkit/models/vd_adapter.py
@@ -11,7 +11,7 @@ from diffusers import Transformer2DModel, FluxTransformer2DModel
 from transformers import T5EncoderModel, CLIPTextModel, CLIPTokenizer, T5Tokenizer, CLIPVisionModelWithProjection
 from toolkit.models.pixtral_vision import PixtralVisionEncoder, PixtralVisionImagePreprocessor, VisionLanguageAdapter
 from transformers import SiglipImageProcessor, SiglipVisionModel
-
+import traceback
 from toolkit.config_modules import AdapterConfig
 from toolkit.paths import REPOS_ROOT
 sys.path.append(REPOS_ROOT)
@@ -291,36 +291,52 @@ class VisionDirectAdapterAttnProcessor(nn.Module):
         # only use one TE or the other. If our adapter is active only use ours
         if self.is_active and self.conditional_embeds is not None:
+            try:
-            adapter_hidden_states = self.conditional_embeds
-            if adapter_hidden_states.shape[0] < batch_size:
-                adapter_hidden_states = torch.cat([
-                    self.unconditional_embeds,
-                    adapter_hidden_states
-                ], dim=0)
-            # if it is image embeds, we need to add a 1 dim at inx 1
-            if len(adapter_hidden_states.shape) == 2:
-                adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
-            # conditional_batch_size = adapter_hidden_states.shape[0]
-            # conditional_query = query
+                adapter_hidden_states = self.conditional_embeds
+                if adapter_hidden_states.shape[0] == batch_size // 2:
+                    adapter_hidden_states = torch.cat([
+                        self.unconditional_embeds,
+                        adapter_hidden_states
+                    ], dim=0)
+                # if it is image embeds, we need to add a 1 dim at inx 1
+                if len(adapter_hidden_states.shape) == 2:
+                    adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
+                # conditional_batch_size = adapter_hidden_states.shape[0]
+                # conditional_query = query
 
-            # for ip-adapter
-            vd_key = self.to_k_adapter(adapter_hidden_states)
-            vd_value = self.to_v_adapter(adapter_hidden_states)
+                # for ip-adapter
+                vd_key = self.to_k_adapter(adapter_hidden_states)
+                vd_value = self.to_v_adapter(adapter_hidden_states)
 
-            vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-            vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                vd_key = vd_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                vd_value = vd_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-            # the output of sdp = (batch, num_heads, seq_len, head_dim)
-            # TODO: add support for attn.scale when we move to Torch 2.1
-            vd_hidden_states = F.scaled_dot_product_attention(
-                query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
-            )
+                # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                # TODO: add support for attn.scale when we move to Torch 2.1
+                vd_hidden_states = F.scaled_dot_product_attention(
+                    query, vd_key, vd_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                )
 
-            vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-            vd_hidden_states = vd_hidden_states.to(query.dtype)
+                vd_hidden_states = vd_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+                vd_hidden_states = vd_hidden_states.to(query.dtype)
 
-            hidden_states = hidden_states + self.scale * vd_hidden_states
+                hidden_states = hidden_states + self.scale * vd_hidden_states
+            except Exception as e:
+                print("Error in VisionDirectAdapterAttnProcessor")
+                # print shapes of all tensors
+                print(f"hidden_states: {hidden_states.shape}")
+                print(f"adapter_hidden_states: {adapter_hidden_states.shape}")
+                print(f"vd_key: {vd_key.shape}")
+                print(f"vd_value: {vd_value.shape}")
+                print(f"vd_hidden_states: {vd_hidden_states.shape}")
+                print(f"query: {query.shape}")
+                print(f"key: {key.shape}")
+                print(f"value: {value.shape}")
+                print(f"inner_dim: {inner_dim}")
+                print(f"head_dim: {head_dim}")
+                print(f"batch_size: {batch_size}")
+                traceback.print_exc()
 
         # linear proj
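
For reference, the BaseSDTrainProcess change above wraps the train-loop hook so that a failing step prints its traceback and the file paths in the offending batch before re-raising. A minimal standalone sketch of that pattern, assuming batch objects expose file_items with a path attribute as the diff does; the function name is illustrative:

import traceback

def run_step_with_batch_report(hook_train_loop, batch_list):
    # Run one training step; on failure, report which files were in
    # the batch, then re-raise so training still fails loudly.
    try:
        return hook_train_loop(batch_list)
    except Exception:
        traceback.print_exc()
        print("Batch Items:")
        for batch in batch_list:
            for item in batch.file_items:
                print(f" - {item.path}")
        raise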
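
The vd_adapter change also tightens the classifier-free-guidance check: the adapter's conditional embeds cover only the conditional half of a CFG-doubled batch, so unconditional embeds are now prepended only when the shape is exactly half the batch (== batch_size // 2) rather than merely smaller (< batch_size). A hedged sketch of that guard in isolation, with an illustrative function name not taken from the codebase:

import torch

def expand_embeds_for_cfg(conditional_embeds, unconditional_embeds, batch_size):
    # In a CFG-doubled batch the layout is [unconditional, conditional],
    # so embeds computed for the conditional half get their
    # unconditional counterparts prepended.
    adapter_hidden_states = conditional_embeds
    if adapter_hidden_states.shape[0] == batch_size // 2:
        adapter_hidden_states = torch.cat(
            [unconditional_embeds, adapter_hidden_states], dim=0
        )
    # image embeds may arrive as (batch, dim); add a sequence dim at index 1
    if adapter_hidden_states.dim() == 2:
        adapter_hidden_states = adapter_hidden_states.unsqueeze(1)
    return adapter_hidden_states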