From f47f9f1f2c945d4c2c592b7cc7c1ea695cf281b3 Mon Sep 17 00:00:00 2001
From: Jaret Burkett
Date: Sun, 10 May 2026 13:44:14 -0600
Subject: [PATCH] Remove dependence on flash attention for hidream o1

---
 .../hidream/hidream_o1_model.py               |  11 +-
 .../src/hidream_o1/qwen3_vl_transformers.py   | 121 +++++++++---------
 2 files changed, 67 insertions(+), 65 deletions(-)

diff --git a/extensions_built_in/diffusion_models/hidream/hidream_o1_model.py b/extensions_built_in/diffusion_models/hidream/hidream_o1_model.py
index c461dc31..22449018 100644
--- a/extensions_built_in/diffusion_models/hidream/hidream_o1_model.py
+++ b/extensions_built_in/diffusion_models/hidream/hidream_o1_model.py
@@ -150,8 +150,12 @@ class HidreamO1Model(BaseModel):
         model_path = self.model_config.name_or_path
 
         self.print_and_status_update("Loading transformer")
-
-        processor = AutoProcessor.from_pretrained(model_path)
+
+        try:
+            processor = AutoProcessor.from_pretrained(model_path)
+        except Exception as e:
+            print(f"Failed to load processor from model path {model_path}, trying original path. Error: {e}")
+            processor = AutoProcessor.from_pretrained(self.model_config.name_or_path_original)
 
         tokenizer = get_tokenizer(processor)
         add_special_tokens(tokenizer)
@@ -452,6 +456,9 @@ class HidreamO1Model(BaseModel):
             save_directory=output_path,
             safe_serialization=True,
         )
+
+        # save tokenizer
+        self.tokenizer.save_pretrained(output_path)
 
         meta_path = os.path.join(output_path, "aitk_meta.yaml")
         with open(meta_path, "w") as f:
diff --git a/extensions_built_in/diffusion_models/hidream/src/hidream_o1/qwen3_vl_transformers.py b/extensions_built_in/diffusion_models/hidream/src/hidream_o1/qwen3_vl_transformers.py
index 0ec7695b..26128e37 100644
--- a/extensions_built_in/diffusion_models/hidream/src/hidream_o1/qwen3_vl_transformers.py
+++ b/extensions_built_in/diffusion_models/hidream/src/hidream_o1/qwen3_vl_transformers.py
@@ -20,22 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-flash_attn_version = os.environ.get("FA_VERSION", "auto")
 USE_BF16_ROPE = os.environ.get("USE_BF16_ROPE", "0")
-# Flash Attention import (FA3 preferred, FA2 fallback)
-_flash_attn_func = None
-if flash_attn_version == "2":
-    from flash_attn import flash_attn_func as _flash_attn_func
-elif flash_attn_version == "3":
-    from flash_attn_interface import flash_attn_func as _flash_attn_func
-else:
-    try:
-        from flash_attn_interface import flash_attn_func as _flash_attn_func
-    except ImportError:
-        try:
-            from flash_attn import flash_attn_func as _flash_attn_func
-        except ImportError:
-            _flash_attn_func = None
 
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
@@ -1528,7 +1513,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
     def _run_decoder_flash(
         self, inputs_embeds, position_ids, token_types, return_mid_results_layers=None
     ):
-        """Run decoder layers with flash attention two-pass approach.
+        """Run decoder layers with two-pass attention.
 
         Replicates the Megatron attention pattern:
         1. Causal attention on AR tokens only (text)
@@ -1538,15 +1523,15 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
 
         This ensures AR tokens only attend causally to other AR tokens,
         while gen tokens attend bidirectionally to everything.
+        Uses the transformers attention dispatch (ALL_ATTENTION_FUNCTIONS),
+        so any backend works (sdpa by default, flash_attention_2 if
+        activated via config._attn_implementation).
+
         Args:
             inputs_embeds: [batch, total_seq_len, hidden]
             position_ids: [3, batch, total_seq_len] - 3D RoPE positions
             token_types: [batch, total_seq_len] - 0=AR, 1=gen
         """
-        assert _flash_attn_func is not None, (
-            "Flash attention is not available. Install flash_attn_interface (FA3) or flash_attn (FA2)."
-        )
-
         text_model = self.language_model
 
         # Compute rotary position embeddings
@@ -1566,16 +1551,17 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
 
         use_gc = text_model.gradient_checkpointing and torch.is_grad_enabled()
 
-        def _flash_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar):
-            """Flash attention layer forward compatible with FSDP2.
+        def _two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar):
+            """Two-pass attention layer forward compatible with FSDP2.
 
             Calls decoder_layer(...) through its __call__ to trigger FSDP hooks
             (which swap DTensor parameters to plain tensors), with self_attn.forward
-            temporarily replaced by a custom two-pass flash attention implementation.
+            temporarily replaced by a custom two-pass attention implementation
+            that goes through the transformers attention dispatch.
             """
             original_attn_forward = decoder_layer.self_attn.forward
 
-            def _custom_flash_attn(
+            def _custom_two_pass_attn(
                 hidden_states, position_embeddings, attention_mask=None, **kwargs
             ):
                 attn = decoder_layer.self_attn
@@ -1583,47 +1569,56 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
                 head_dim = attn.head_dim
                 hidden_shape = (*input_shape, -1, head_dim)
 
-                # Q, K, V projections
-                q = attn.q_norm(attn.q_proj(hidden_states).view(hidden_shape))
-                k = attn.k_norm(attn.k_proj(hidden_states).view(hidden_shape))
-                v = attn.v_proj(hidden_states).view(hidden_shape)
+                # Q, K, V projections in [B, H, S, D]
+                q = attn.q_norm(
+                    attn.q_proj(hidden_states).view(hidden_shape)
+                ).transpose(1, 2)
+                k = attn.k_norm(
+                    attn.k_proj(hidden_states).view(hidden_shape)
+                ).transpose(1, 2)
+                v = attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
-                # Apply rotary position embedding (expects [B, H, S, D])
+                # Apply rotary position embedding
                 cos_pe, sin_pe = position_embeddings
-                q_r = q.transpose(1, 2)  # [B, H, S, D]
-                k_r = k.transpose(1, 2)  # [B, KVH, S, D]
-                q_r, k_r = apply_rotary_pos_emb(q_r, k_r, cos_pe, sin_pe)
-                q = q_r.transpose(1, 2).contiguous()  # [B, S, H, D]
-                k = k_r.transpose(1, 2).contiguous()  # [B, S, KVH, D]
-                v = v.contiguous()
+                q, k = apply_rotary_pos_emb(q, k, cos_pe, sin_pe)
 
-                softmax_scale = head_dim**-0.5
+                scaling = head_dim**-0.5
 
-                # --- Two-pass flash attention ---
-                # Pass 1: causal attention on AR tokens only
-                q_ar = q[:, idx_ar].contiguous()
-                k_ar = k[:, idx_ar].contiguous()
-                v_ar = v[:, idx_ar].contiguous()
-                result_ar = _flash_attn_func(
-                    q_ar.to(torch.bfloat16),
-                    k_ar.to(torch.bfloat16),
-                    v_ar.to(torch.bfloat16),
-                    softmax_scale=softmax_scale,
-                    causal=True,
-                )
-                out_ar = result_ar[0] if isinstance(result_ar, tuple) else result_ar
+                # Attention dispatch: sdpa/eager by default, flash_attention_2
+                # if activated via config._attn_implementation.
+                attention_interface: Callable = eager_attention_forward
+                if attn.config._attn_implementation != "eager":
+                    attention_interface = ALL_ATTENTION_FUNCTIONS[
+                        attn.config._attn_implementation
+                    ]
 
-                # Pass 2: full (bidirectional) attention on all tokens
-                result_full = _flash_attn_func(
-                    q.to(torch.bfloat16),
-                    k.to(torch.bfloat16),
-                    v.to(torch.bfloat16),
-                    softmax_scale=softmax_scale,
-                    causal=False,
-                )
-                out_full = (
-                    result_full[0] if isinstance(result_full, tuple) else result_full
-                )
+                # --- Two-pass attention ---
+                # Pass 1: causal on AR tokens only (slice on seq dim)
+                q_ar = q[:, :, idx_ar].contiguous()
+                k_ar = k[:, :, idx_ar].contiguous()
+                v_ar = v[:, :, idx_ar].contiguous()
+                out_ar, _ = attention_interface(
+                    attn,
+                    q_ar,
+                    k_ar,
+                    v_ar,
+                    attention_mask=None,
+                    dropout=0.0,
+                    scaling=scaling,
+                    is_causal=True,
+                )  # [B, n_ar, H, D]
+
+                # Pass 2: full (bidirectional) on all tokens
+                out_full, _ = attention_interface(
+                    attn,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    dropout=0.0,
+                    scaling=scaling,
+                    is_causal=False,
+                )  # [B, S, H, D]
 
                 # Replace AR positions with causal result
                 out_full = out_full.clone()
@@ -1638,7 +1633,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
                 # to avoid nested checkpointing (the outer loop handles GC).
                 _saved_gc = decoder_layer.gradient_checkpointing
                 decoder_layer.gradient_checkpointing = False
-                decoder_layer.self_attn.forward = _custom_flash_attn
+                decoder_layer.self_attn.forward = _custom_two_pass_attn
                 try:
                     hidden_states = decoder_layer(
                         hidden_states,
@@ -1653,7 +1648,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
         for layer_idx, decoder_layer in enumerate(text_model.layers):
             if use_gc:
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    _flash_layer_forward,
+                    _two_pass_layer_forward,
                     hidden_states,
                     decoder_layer,
                     cos,
@@ -1662,7 +1657,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
                     use_reentrant=False,
                 )
             else:
-                hidden_states = _flash_layer_forward(
+                hidden_states = _two_pass_layer_forward(
                     hidden_states,
                     decoder_layer,
                     cos,
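
Note (not part of the patch): the two-pass scatter above can be sanity-checked
in isolation. The sketch below is a minimal approximation that uses plain
torch.nn.functional.scaled_dot_product_attention instead of the transformers
dispatch; the shapes, the toy AR/gen split, and every name in it are
illustrative assumptions, not code from this repo.

    # Two-pass attention sketch: causal over AR tokens, bidirectional over
    # all tokens, then AR rows overwritten with the causal result.
    import torch
    import torch.nn.functional as F

    B, H, S, D = 1, 2, 6, 8
    token_types = torch.tensor([[0, 0, 0, 1, 1, 1]])  # 0=AR, 1=gen
    idx_ar = (token_types[0] == 0).nonzero(as_tuple=True)[0]

    q, k, v = (torch.randn(B, H, S, D) for _ in range(3))

    # Pass 1: causal attention restricted to AR tokens (slice the seq dim).
    out_ar = F.scaled_dot_product_attention(
        q[:, :, idx_ar], k[:, :, idx_ar], v[:, :, idx_ar], is_causal=True
    )

    # Pass 2: full bidirectional attention over every token.
    out_full = F.scaled_dot_product_attention(q, k, v, is_causal=False)

    # Scatter: AR positions keep the causal pass, gen positions keep pass 2.
    # (sdpa returns [B, H, S, D], so the scatter is on dim 2; the patch's
    # attention_interface returns [B, S, H, D] and scatters on dim 1.)
    out = out_full.clone()
    out[:, :, idx_ar] = out_ar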
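
The forward-swap in _two_pass_layer_forward relies only on how nn.Module
dispatches calls. A tiny standalone illustration of that pattern, with a
hypothetical nn.Linear standing in for decoder_layer.self_attn:

    import torch
    import torch.nn as nn

    layer = nn.Linear(4, 4)  # stand-in module, not from this repo
    original_forward = layer.forward

    def custom_forward(x):
        # A real swap would run the two-pass attention here.
        return original_forward(x)

    layer.forward = custom_forward
    try:
        # Calling layer(...) goes through __call__, so FSDP/module hooks
        # still fire before the swapped-in forward runs.
        out = layer(torch.zeros(1, 4))
    finally:
        layer.forward = original_forward  # always restore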
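
Nothing above pins the backend: the dispatch reads config._attn_implementation,
which is the standard transformers setting populated at load time. A hedged
illustration of that knob (the model path and loading call are placeholders,
not this repo's loader):

    from transformers import AutoModel

    # attn_implementation populates config._attn_implementation, which the
    # dispatch in this patch reads; "flash_attention_2" needs flash-attn
    # installed, while "sdpa" works out of the box.
    model = AutoModel.from_pretrained(
        "org/some-model",  # placeholder path
        attn_implementation="sdpa",
    )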