mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-21 06:49:08 +00:00
Make flash attn optional. Handle larger batch sizes.
This commit is contained in:
@@ -2,12 +2,20 @@ from typing import Optional
|
||||
import torch
|
||||
from .attention import HiDreamAttention
|
||||
|
||||
# Try to import Flash Attention first
|
||||
flash_attn_available = False
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_func
|
||||
USE_FLASH_ATTN3 = True
|
||||
except:
|
||||
from flash_attn import flash_attn_func
|
||||
USE_FLASH_ATTN3 = False
|
||||
flash_attn_available = True
|
||||
except ImportError:
|
||||
try:
|
||||
from flash_attn import flash_attn_func
|
||||
USE_FLASH_ATTN3 = False
|
||||
flash_attn_available = True
|
||||
except ImportError:
|
||||
USE_FLASH_ATTN3 = False
|
||||
flash_attn_available = False
|
||||
|
||||
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
|
||||
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@@ -18,10 +26,28 @@ def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> t
|
||||
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
||||
|
||||
def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
|
||||
if USE_FLASH_ATTN3:
|
||||
hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
|
||||
if flash_attn_available:
|
||||
if USE_FLASH_ATTN3:
|
||||
hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
|
||||
else:
|
||||
hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
|
||||
else:
|
||||
hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
|
||||
# Use torch's scaled dot-product attention as fallback
|
||||
# Reshape for torch.nn.functional.scaled_dot_product_attention which expects [batch, heads, seq_len, head_dim]
|
||||
query = query.transpose(1, 2) # [batch, heads, seq_len, head_dim]
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
hidden_states = torch.nn.functional.scaled_dot_product_attention(
|
||||
query, key, value,
|
||||
attn_mask=None,
|
||||
dropout_p=0.0,
|
||||
is_causal=False
|
||||
)
|
||||
|
||||
# Restore original shape
|
||||
hidden_states = hidden_states.transpose(1, 2) # [batch, seq_len, heads, head_dim]
|
||||
|
||||
hidden_states = hidden_states.flatten(-2)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user