Remove dependence on flash attention for hidream o1

Jaret Burkett
2026-05-10 13:44:14 -06:00
parent bbdb05d744
commit f47f9f1f2c
2 changed files with 67 additions and 65 deletions

View File

@@ -150,8 +150,12 @@ class HidreamO1Model(BaseModel):
         model_path = self.model_config.name_or_path
         self.print_and_status_update("Loading transformer")
-        processor = AutoProcessor.from_pretrained(model_path)
+        try:
+            processor = AutoProcessor.from_pretrained(model_path)
+        except Exception as e:
+            print(f"Failed to load processor from model path {model_path}, trying original path. Error: {e}")
+            processor = AutoProcessor.from_pretrained(self.model_config.name_or_path_original)
         tokenizer = get_tokenizer(processor)
         add_special_tokens(tokenizer)
@@ -452,6 +456,9 @@ class HidreamO1Model(BaseModel):
             save_directory=output_path,
             safe_serialization=True,
         )
+        # save processor
+        self.tokenizer.save_pretrained(output_path)
         meta_path = os.path.join(output_path, "aitk_meta.yaml")
         with open(meta_path, "w") as f:

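The hunk above adds a fallback when processor files are missing from the fine-tuned checkpoint directory, and saves the tokenizer alongside the exported weights so later loads succeed. A minimal standalone sketch of the same load-with-fallback pattern; the function name and paths are illustrative, not the repo's config fields:

    from transformers import AutoProcessor

    def load_processor_with_fallback(checkpoint_path, original_path):
        # Prefer the fine-tuned checkpoint; fall back to the original base
        # model if the checkpoint was saved without processor/tokenizer files.
        try:
            return AutoProcessor.from_pretrained(checkpoint_path)
        except Exception as e:
            print(f"Failed to load processor from {checkpoint_path}: {e}; trying {original_path}")
            return AutoProcessor.from_pretrained(original_path)
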
View File

@@ -20,22 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-flash_attn_version = os.environ.get("FA_VERSION", "auto")
 USE_BF16_ROPE = os.environ.get("USE_BF16_ROPE", "0")
-# Flash Attention import (FA3 preferred, FA2 fallback)
-_flash_attn_func = None
-if flash_attn_version == "2":
-    from flash_attn import flash_attn_func as _flash_attn_func
-elif flash_attn_version == "3":
-    from flash_attn_interface import flash_attn_func as _flash_attn_func
-else:
-    try:
-        from flash_attn_interface import flash_attn_func as _flash_attn_func
-    except ImportError:
-        try:
-            from flash_attn import flash_attn_func as _flash_attn_func
-        except ImportError:
-            _flash_attn_func = None
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
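With the FA2/FA3 import shim removed, the attention backend is chosen through transformers' own dispatch rather than the FA_VERSION environment variable. A hedged sketch of how a caller can still opt into flash attention at load time; the model class and path are placeholders, and the attn_implementation argument is the standard transformers mechanism in recent releases:

    import torch
    from transformers import AutoModelForCausalLM

    # "sdpa" is the default backend in recent transformers releases;
    # "flash_attention_2" is only valid when the flash-attn package is installed.
    model = AutoModelForCausalLM.from_pretrained(
        "path/to/model",                # placeholder
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",     # or "flash_attention_2" / "eager"
    )
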
@@ -1528,7 +1513,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
     def _run_decoder_flash(
         self, inputs_embeds, position_ids, token_types, return_mid_results_layers=None
     ):
-        """Run decoder layers with flash attention two-pass approach.
+        """Run decoder layers with two-pass attention.
         Replicates the Megatron attention pattern:
         1. Causal attention on AR tokens only (text)
@@ -1538,15 +1523,15 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
         This ensures AR tokens only attend causally to other AR tokens,
         while gen tokens attend bidirectionally to everything.
+        Uses the transformers attention dispatch (ALL_ATTENTION_FUNCTIONS),
+        so any backend works (sdpa by default, flash_attention_2 if
+        activated via config._attn_implementation).
         Args:
             inputs_embeds: [batch, total_seq_len, hidden]
             position_ids: [3, batch, total_seq_len] - 3D RoPE positions
             token_types: [batch, total_seq_len] - 0=AR, 1=gen
         """
-        assert _flash_attn_func is not None, (
-            "Flash attention is not available. Install flash_attn_interface (FA3) or flash_attn (FA2)."
-        )
         text_model = self.language_model
         # Compute rotary position embeddings
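The two-pass pattern this docstring describes can also be written directly against PyTorch SDPA. The sketch below is illustrative only (it ignores GQA key/value expansion and RoPE), the function name is invented, and tensors are assumed to be [batch, heads, seq, head_dim] with ar_idx indexing the AR positions along the sequence dimension:

    import torch
    import torch.nn.functional as F

    def two_pass_attention(q, k, v, ar_idx):
        # Pass 1: causal attention restricted to the AR (text) tokens.
        out_ar = F.scaled_dot_product_attention(
            q[:, :, ar_idx], k[:, :, ar_idx], v[:, :, ar_idx], is_causal=True
        )
        # Pass 2: full bidirectional attention over every token.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        # AR positions keep the causal result; gen tokens keep the full one.
        out = out.clone()
        out[:, :, ar_idx] = out_ar
        return out
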
@@ -1566,16 +1551,17 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
         use_gc = text_model.gradient_checkpointing and torch.is_grad_enabled()
-        def _flash_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar):
-            """Flash attention layer forward compatible with FSDP2.
+        def _two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar):
+            """Two-pass attention layer forward compatible with FSDP2.
             Calls decoder_layer(...) through its __call__ to trigger FSDP hooks
             (which swap DTensor parameters to plain tensors), with self_attn.forward
-            temporarily replaced by a custom two-pass flash attention implementation.
+            temporarily replaced by a custom two-pass attention implementation
+            that goes through the transformers attention dispatch.
             """
             original_attn_forward = decoder_layer.self_attn.forward
-            def _custom_flash_attn(
+            def _custom_two_pass_attn(
                 hidden_states, position_embeddings, attention_mask=None, **kwargs
             ):
                 attn = decoder_layer.self_attn
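The swap-and-restore trick in this helper (keep calling the decoder layer through __call__ so FSDP2 hooks still fire, while overriding only self_attn.forward for the duration of the call) is a general pattern. A small generic sketch with invented names, assuming nothing about the repo beyond what the docstring states:

    import contextlib

    @contextlib.contextmanager
    def swapped_forward(module, custom_forward):
        # Temporarily replace module.forward, restoring it even on error.
        # The parent layer must still be invoked via __call__ so that
        # pre/post forward hooks (e.g. FSDP parameter gathering) run.
        original = module.forward
        module.forward = custom_forward
        try:
            yield
        finally:
            module.forward = original

Used as: with swapped_forward(layer.self_attn, my_attn): out = layer(hidden_states, ...).
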
@@ -1583,47 +1569,56 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
                 head_dim = attn.head_dim
                 hidden_shape = (*input_shape, -1, head_dim)
-                # Q, K, V projections
-                q = attn.q_norm(attn.q_proj(hidden_states).view(hidden_shape))
-                k = attn.k_norm(attn.k_proj(hidden_states).view(hidden_shape))
-                v = attn.v_proj(hidden_states).view(hidden_shape)
+                # Q, K, V projections in [B, H, S, D]
+                q = attn.q_norm(
+                    attn.q_proj(hidden_states).view(hidden_shape)
+                ).transpose(1, 2)
+                k = attn.k_norm(
+                    attn.k_proj(hidden_states).view(hidden_shape)
+                ).transpose(1, 2)
+                v = attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-                # Apply rotary position embedding (expects [B, H, S, D])
+                # Apply rotary position embedding
                 cos_pe, sin_pe = position_embeddings
-                q_r = q.transpose(1, 2)  # [B, H, S, D]
-                k_r = k.transpose(1, 2)  # [B, KVH, S, D]
-                q_r, k_r = apply_rotary_pos_emb(q_r, k_r, cos_pe, sin_pe)
-                q = q_r.transpose(1, 2).contiguous()  # [B, S, H, D]
-                k = k_r.transpose(1, 2).contiguous()  # [B, S, KVH, D]
-                v = v.contiguous()
+                q, k = apply_rotary_pos_emb(q, k, cos_pe, sin_pe)
-                softmax_scale = head_dim**-0.5
+                scaling = head_dim**-0.5
-                # --- Two-pass flash attention ---
-                # Pass 1: causal attention on AR tokens only
-                q_ar = q[:, idx_ar].contiguous()
-                k_ar = k[:, idx_ar].contiguous()
-                v_ar = v[:, idx_ar].contiguous()
-                result_ar = _flash_attn_func(
-                    q_ar.to(torch.bfloat16),
-                    k_ar.to(torch.bfloat16),
-                    v_ar.to(torch.bfloat16),
-                    softmax_scale=softmax_scale,
-                    causal=True,
-                )
-                out_ar = result_ar[0] if isinstance(result_ar, tuple) else result_ar
+                # Attention dispatch — sdpa/eager by default, flash_attention_2
+                # if activated via config._attn_implementation.
+                attention_interface: Callable = eager_attention_forward
+                if attn.config._attn_implementation != "eager":
+                    attention_interface = ALL_ATTENTION_FUNCTIONS[
+                        attn.config._attn_implementation
+                    ]
-                # Pass 2: full (bidirectional) attention on all tokens
-                result_full = _flash_attn_func(
-                    q.to(torch.bfloat16),
-                    k.to(torch.bfloat16),
-                    v.to(torch.bfloat16),
-                    softmax_scale=softmax_scale,
-                    causal=False,
-                )
-                out_full = (
-                    result_full[0] if isinstance(result_full, tuple) else result_full
-                )
+                # --- Two-pass attention ---
+                # Pass 1: causal on AR tokens only (slice on seq dim)
+                q_ar = q[:, :, idx_ar].contiguous()
+                k_ar = k[:, :, idx_ar].contiguous()
+                v_ar = v[:, :, idx_ar].contiguous()
+                out_ar, _ = attention_interface(
+                    attn,
+                    q_ar,
+                    k_ar,
+                    v_ar,
+                    attention_mask=None,
+                    dropout=0.0,
+                    scaling=scaling,
+                    is_causal=True,
+                )  # [B, n_ar, H, D]
+                # Pass 2: full (bidirectional) on all tokens
+                out_full, _ = attention_interface(
+                    attn,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    dropout=0.0,
+                    scaling=scaling,
+                    is_causal=False,
+                )  # [B, S, H, D]
                 # Replace AR positions with causal result
                 out_full = out_full.clone()
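For reference, the backend-resolution idiom introduced above, shown in isolation. eager_attention_forward is the model file's own pure-PyTorch fallback; the import path for ALL_ATTENTION_FUNCTIONS is an assumption (transformers.modeling_utils in recent releases), since the diff does not show the new import lines:

    from typing import Callable
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    def resolve_attention(attn_module, eager_fallback: Callable) -> Callable:
        # Return the kernel registered for the configured implementation
        # ("sdpa", "flash_attention_2", ...), or the eager fallback.
        impl = attn_module.config._attn_implementation
        if impl == "eager":
            return eager_fallback
        return ALL_ATTENTION_FUNCTIONS[impl]
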
@@ -1638,7 +1633,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
             # to avoid nested checkpointing (the outer loop handles GC).
             _saved_gc = decoder_layer.gradient_checkpointing
             decoder_layer.gradient_checkpointing = False
-            decoder_layer.self_attn.forward = _custom_flash_attn
+            decoder_layer.self_attn.forward = _custom_two_pass_attn
             try:
                 hidden_states = decoder_layer(
                     hidden_states,
@@ -1653,7 +1648,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
         for layer_idx, decoder_layer in enumerate(text_model.layers):
             if use_gc:
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    _flash_layer_forward,
+                    _two_pass_layer_forward,
                     hidden_states,
                     decoder_layer,
                     cos,
@@ -1662,7 +1657,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
                     use_reentrant=False,
                 )
             else:
-                hidden_states = _flash_layer_forward(
+                hidden_states = _two_pass_layer_forward(
                     hidden_states,
                     decoder_layer,
                     cos,
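
For completeness, the checkpointed layer loop around these call sites in generic form. This is a sketch with invented names; use_reentrant=False selects the non-reentrant checkpoint variant, which is what the diff's calls pass:

    import torch

    def run_layers(layers, hidden_states, *extra, use_gc=True):
        for layer in layers:
            if use_gc and torch.is_grad_enabled():
                # Recompute this layer's activations in backward
                # instead of storing them.
                hidden_states = torch.utils.checkpoint.checkpoint(
                    layer, hidden_states, *extra, use_reentrant=False
                )
            else:
                hidden_states = layer(hidden_states, *extra)
        return hidden_states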