mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-05-11 08:20:35 +00:00)
Remove dependence on flash attention for hidream o1
@@ -150,8 +150,12 @@ class HidreamO1Model(BaseModel):
         model_path = self.model_config.name_or_path

         self.print_and_status_update("Loading transformer")

-        processor = AutoProcessor.from_pretrained(model_path)
+        try:
+            processor = AutoProcessor.from_pretrained(model_path)
+        except Exception as e:
+            print(f"Failed to load processor from model path {model_path}, trying original path. Error: {e}")
+            processor = AutoProcessor.from_pretrained(self.model_config.name_or_path_original)

         tokenizer = get_tokenizer(processor)
         add_special_tokens(tokenizer)
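In isolation, the fallback added above looks like this (a minimal sketch; the function name and paths are illustrative, not part of the commit):

    from transformers import AutoProcessor

    def load_processor_with_fallback(finetuned_path: str, base_path: str):
        # A fine-tuned checkpoint may ship without processor files, so a
        # failed load from it falls back to the original base repo.
        try:
            return AutoProcessor.from_pretrained(finetuned_path)
        except Exception as e:
            print(f"Failed to load processor from {finetuned_path}: {e}")
            return AutoProcessor.from_pretrained(base_path)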
@@ -452,6 +456,9 @@ class HidreamO1Model(BaseModel):
             save_directory=output_path,
             safe_serialization=True,
         )

+        # save processor
+        self.tokenizer.save_pretrained(output_path)
+
         meta_path = os.path.join(output_path, "aitk_meta.yaml")
         with open(meta_path, "w") as f:
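The added save is the other half of that fix: once the tokenizer/processor files sit next to the weights, the checkpoint directory loads on its own, and the fallback in the first hunk is only needed for older saves. A hedged sketch of the pattern (names are placeholders):

    def save_checkpoint(model, tokenizer, output_path: str) -> None:
        # Persist weights and tokenizer together so from_pretrained on
        # output_path succeeds without consulting the base repo.
        model.save_pretrained(save_directory=output_path, safe_serialization=True)
        tokenizer.save_pretrained(output_path)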
@@ -20,22 +20,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F

-flash_attn_version = os.environ.get("FA_VERSION", "auto")
 USE_BF16_ROPE = os.environ.get("USE_BF16_ROPE", "0")
-# Flash Attention import (FA3 preferred, FA2 fallback)
-_flash_attn_func = None
-if flash_attn_version == "2":
-    from flash_attn import flash_attn_func as _flash_attn_func
-elif flash_attn_version == "3":
-    from flash_attn_interface import flash_attn_func as _flash_attn_func
-else:
-    try:
-        from flash_attn_interface import flash_attn_func as _flash_attn_func
-    except ImportError:
-        try:
-            from flash_attn import flash_attn_func as _flash_attn_func
-        except ImportError:
-            _flash_attn_func = None
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache
 from transformers.generation import GenerationMixin
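With the hand-rolled FA3/FA2 import chain gone, backend choice moves to the standard transformers mechanism. A sketch of how a caller could still opt into flash attention (the model id is hypothetical):

    from transformers import AutoModelForCausalLM

    # "sdpa" is the default backend; "flash_attention_2" only works when
    # the flash-attn package is actually installed.
    model = AutoModelForCausalLM.from_pretrained(
        "org/model-id",  # hypothetical
        attn_implementation="sdpa",  # or "eager" / "flash_attention_2"
    )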
@@ -1528,7 +1513,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
     def _run_decoder_flash(
         self, inputs_embeds, position_ids, token_types, return_mid_results_layers=None
     ):
-        """Run decoder layers with flash attention two-pass approach.
+        """Run decoder layers with two-pass attention.

         Replicates the Megatron attention pattern:
         1. Causal attention on AR tokens only (text)
@@ -1538,15 +1523,15 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
         This ensures AR tokens only attend causally to other AR tokens,
         while gen tokens attend bidirectionally to everything.

+        Uses the transformers attention dispatch (ALL_ATTENTION_FUNCTIONS),
+        so any backend works (sdpa by default, flash_attention_2 if
+        activated via config._attn_implementation).
+
         Args:
             inputs_embeds: [batch, total_seq_len, hidden]
             position_ids: [3, batch, total_seq_len] - 3D RoPE positions
             token_types: [batch, total_seq_len] - 0=AR, 1=gen
         """
-        assert _flash_attn_func is not None, (
-            "Flash attention is not available. Install flash_attn_interface (FA3) or flash_attn (FA2)."
-        )
-
         text_model = self.language_model

         # Compute rotary position embeddings
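The ALL_ATTENTION_FUNCTIONS registry the new docstring mentions is a plain mapping from backend name to a common call signature. A minimal sketch of driving it directly, assuming a recent transformers release (the SimpleNamespace shim stands in for a real attention module and is not part of the commit):

    import torch
    from types import SimpleNamespace
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    attn_fn = ALL_ATTENTION_FUNCTIONS["sdpa"]
    # Minimal attributes the sdpa backend reads off its module argument.
    module = SimpleNamespace(num_key_value_groups=1, training=False)
    q = k = v = torch.randn(1, 8, 16, 64)  # [B, H, S, D]
    out, _ = attn_fn(module, q, k, v, attention_mask=None,
                     dropout=0.0, scaling=64 ** -0.5, is_causal=True)
    print(out.shape)  # [B, S, H, D] -> torch.Size([1, 16, 8, 64])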
@@ -1566,16 +1551,17 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
 
         use_gc = text_model.gradient_checkpointing and torch.is_grad_enabled()

-        def _flash_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar):
-            """Flash attention layer forward compatible with FSDP2.
+        def _two_pass_layer_forward(hidden_states, decoder_layer, cos, sin, idx_ar):
+            """Two-pass attention layer forward compatible with FSDP2.

             Calls decoder_layer(...) through its __call__ to trigger FSDP hooks
             (which swap DTensor parameters to plain tensors), with self_attn.forward
-            temporarily replaced by a custom two-pass flash attention implementation.
+            temporarily replaced by a custom two-pass attention implementation
+            that goes through the transformers attention dispatch.
             """
             original_attn_forward = decoder_layer.self_attn.forward

-            def _custom_flash_attn(
+            def _custom_two_pass_attn(
                 hidden_states, position_embeddings, attention_mask=None, **kwargs
             ):
                 attn = decoder_layer.self_attn
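The swap-and-restore around self_attn is what keeps FSDP2 happy: the parent layer is still entered through __call__, so its pre/post-forward hooks (which materialize DTensor parameters) run as usual, and only the inner attention math is replaced. Schematically (names are illustrative):

    def call_with_patched_attn(layer, custom_attn_forward, *args, **kwargs):
        # Patch the submodule's forward, but invoke the *layer* through
        # __call__ so FSDP/checkpoint hooks still fire; always restore.
        original = layer.self_attn.forward
        layer.self_attn.forward = custom_attn_forward
        try:
            return layer(*args, **kwargs)
        finally:
            layer.self_attn.forward = original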
@@ -1583,47 +1569,56 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
                 head_dim = attn.head_dim
                 hidden_shape = (*input_shape, -1, head_dim)

-                # Q, K, V projections
-                q = attn.q_norm(attn.q_proj(hidden_states).view(hidden_shape))
-                k = attn.k_norm(attn.k_proj(hidden_states).view(hidden_shape))
-                v = attn.v_proj(hidden_states).view(hidden_shape)
+                # Q, K, V projections in [B, H, S, D]
+                q = attn.q_norm(
+                    attn.q_proj(hidden_states).view(hidden_shape)
+                ).transpose(1, 2)
+                k = attn.k_norm(
+                    attn.k_proj(hidden_states).view(hidden_shape)
+                ).transpose(1, 2)
+                v = attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

-                # Apply rotary position embedding (expects [B, H, S, D])
+                # Apply rotary position embedding
                 cos_pe, sin_pe = position_embeddings
-                q_r = q.transpose(1, 2)  # [B, H, S, D]
-                k_r = k.transpose(1, 2)  # [B, KVH, S, D]
-                q_r, k_r = apply_rotary_pos_emb(q_r, k_r, cos_pe, sin_pe)
-                q = q_r.transpose(1, 2).contiguous()  # [B, S, H, D]
-                k = k_r.transpose(1, 2).contiguous()  # [B, S, KVH, D]
-                v = v.contiguous()
+                q, k = apply_rotary_pos_emb(q, k, cos_pe, sin_pe)

-                softmax_scale = head_dim**-0.5
+                scaling = head_dim**-0.5

-                # --- Two-pass flash attention ---
-                # Pass 1: causal attention on AR tokens only
-                q_ar = q[:, idx_ar].contiguous()
-                k_ar = k[:, idx_ar].contiguous()
-                v_ar = v[:, idx_ar].contiguous()
-                result_ar = _flash_attn_func(
-                    q_ar.to(torch.bfloat16),
-                    k_ar.to(torch.bfloat16),
-                    v_ar.to(torch.bfloat16),
-                    softmax_scale=softmax_scale,
-                    causal=True,
-                )
-                out_ar = result_ar[0] if isinstance(result_ar, tuple) else result_ar
+                # Attention dispatch — sdpa/eager by default, flash_attention_2
+                # if activated via config._attn_implementation.
+                attention_interface: Callable = eager_attention_forward
+                if attn.config._attn_implementation != "eager":
+                    attention_interface = ALL_ATTENTION_FUNCTIONS[
+                        attn.config._attn_implementation
+                    ]

-                # Pass 2: full (bidirectional) attention on all tokens
-                result_full = _flash_attn_func(
-                    q.to(torch.bfloat16),
-                    k.to(torch.bfloat16),
-                    v.to(torch.bfloat16),
-                    softmax_scale=softmax_scale,
-                    causal=False,
-                )
-                out_full = (
-                    result_full[0] if isinstance(result_full, tuple) else result_full
-                )
+                # --- Two-pass attention ---
+                # Pass 1: causal on AR tokens only (slice on seq dim)
+                q_ar = q[:, :, idx_ar].contiguous()
+                k_ar = k[:, :, idx_ar].contiguous()
+                v_ar = v[:, :, idx_ar].contiguous()
+                out_ar, _ = attention_interface(
+                    attn,
+                    q_ar,
+                    k_ar,
+                    v_ar,
+                    attention_mask=None,
+                    dropout=0.0,
+                    scaling=scaling,
+                    is_causal=True,
+                )  # [B, n_ar, H, D]
+
+                # Pass 2: full (bidirectional) on all tokens
+                out_full, _ = attention_interface(
+                    attn,
+                    q,
+                    k,
+                    v,
+                    attention_mask=None,
+                    dropout=0.0,
+                    scaling=scaling,
+                    is_causal=False,
+                )  # [B, S, H, D]

                 # Replace AR positions with causal result
                 out_full = out_full.clone()
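Stripped of the dispatch plumbing, the two-pass trick is easy to state: run causal attention over the AR (text) subsequence, run full bidirectional attention over everything, then overwrite the AR positions with the causal result. A self-contained sketch using plain SDPA (shapes follow the comments in the diff, here kept in [B, H, S, D]):

    import torch
    import torch.nn.functional as F

    def two_pass_attention(q, k, v, idx_ar):
        """q, k, v: [B, H, S, D]; idx_ar: positions of AR (text) tokens."""
        # Pass 1: causal attention restricted to the AR subsequence.
        out_ar = F.scaled_dot_product_attention(
            q[:, :, idx_ar], k[:, :, idx_ar], v[:, :, idx_ar], is_causal=True
        )
        # Pass 2: bidirectional attention over the full sequence.
        out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        # AR tokens keep the causal result; gen tokens keep the full one.
        out = out.clone()
        out[:, :, idx_ar] = out_ar
        return out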
@@ -1638,7 +1633,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
             # to avoid nested checkpointing (the outer loop handles GC).
             _saved_gc = decoder_layer.gradient_checkpointing
             decoder_layer.gradient_checkpointing = False
-            decoder_layer.self_attn.forward = _custom_flash_attn
+            decoder_layer.self_attn.forward = _custom_two_pass_attn
             try:
                 hidden_states = decoder_layer(
                     hidden_states,
@@ -1653,7 +1648,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
         for layer_idx, decoder_layer in enumerate(text_model.layers):
             if use_gc:
                 hidden_states = torch.utils.checkpoint.checkpoint(
-                    _flash_layer_forward,
+                    _two_pass_layer_forward,
                     hidden_states,
                     decoder_layer,
                     cos,
@@ -1662,7 +1657,7 @@ class Qwen3VLModel(Qwen3VLPreTrainedModel):
                     use_reentrant=False,
                 )
             else:
-                hidden_states = _flash_layer_forward(
+                hidden_states = _two_pass_layer_forward(
                     hidden_states,
                     decoder_layer,
                     cos,
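Both call sites feed the same closure; under gradient checkpointing it goes through torch.utils.checkpoint.checkpoint with use_reentrant=False, the non-reentrant mode PyTorch now recommends, which recomputes the layer during backward instead of storing its activations. A tiny runnable illustration (layer_forward is a stand-in, not the commit's function):

    import torch

    def layer_forward(x, scale):
        # Stand-in for _two_pass_layer_forward; re-run during backward.
        return torch.tanh(x) * scale

    x = torch.randn(2, 4, requires_grad=True)
    y = torch.utils.checkpoint.checkpoint(
        layer_forward, x, 2.0, use_reentrant=False
    )
    y.sum().backward()  # activations inside layer_forward are recomputed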