Make flash attn optional. Handle larger batch sizes.

Jaret Burkett
2025-04-14 14:34:46 +00:00
parent 89c0f688db
commit 524bd2edfc
3 changed files with 34 additions and 15 deletions

View File

@@ -58,7 +58,8 @@ RUN echo "Cache bust: ${CACHEBUST}" && \
 WORKDIR /app/ai-toolkit
 # Install Python dependencies
-RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir -r requirements.txt && \
+    pip install flash-attn --no-build-isolation --no-cache-dir
 # Build UI
 WORKDIR /app/ai-toolkit/ui

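Because the attention code further down now probes for Flash Attention at import time and falls back to PyTorch SDPA when it is missing, the image simply installs flash-attn explicitly so the fast path is available. To confirm which backend a given environment will actually use, a small probe that mirrors the new import ladder can be run inside the container. This is only a sketch; check_backend is a hypothetical name, not part of the repo.

# Minimal sketch (not part of the repo): report which attention backend the
# toolkit will pick, mirroring the import ladder added in the attention module.
def check_backend() -> str:
    try:
        from flash_attn_interface import flash_attn_func  # noqa: F401  # FlashAttention 3
        return "flash-attn 3"
    except ImportError:
        try:
            from flash_attn import flash_attn_func  # noqa: F401  # FlashAttention 2
            return "flash-attn 2"
        except ImportError:
            return "torch scaled_dot_product_attention fallback"

if __name__ == "__main__":
    print(check_backend())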
View File

@@ -129,14 +129,6 @@ class HidreamModel(BaseModel):
                 torch_dtype=torch.bfloat16
             )
-            # count the params
-            sd = transformer.state_dict()
-            num_params = sum(p.numel() for p in sd.values())
-            print(f"Number of params in transformer: {num_params}")
-            # count params with name expert in them
-            num_expert_params = sum(p.numel() for k, p in sd.items() if 'expert' in k)
-            print(f"Number of params in transformer with expert in them: {num_expert_params}")
             if not self.low_vram:
                 transformer.to(self.device_torch, dtype=dtype)

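The lines dropped above were debug output that pulled a state_dict just to count parameters. If those counts are ever wanted again, they can be taken directly from the module; a minimal sketch, where count_params is a hypothetical helper and not part of the repo:

import torch.nn as nn

# Sketch only: total parameter count plus the subset whose names contain
# "expert", computed from the module itself rather than a state_dict.
def count_params(module: nn.Module) -> tuple[int, int]:
    total = sum(p.numel() for p in module.parameters())
    expert = sum(p.numel() for name, p in module.named_parameters() if "expert" in name)
    return total, expert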
View File

@@ -2,12 +2,20 @@ from typing import Optional
 import torch
 from .attention import HiDreamAttention
+# Try to import Flash Attention first
+flash_attn_available = False
 try:
     from flash_attn_interface import flash_attn_func
     USE_FLASH_ATTN3 = True
-except:
-    from flash_attn import flash_attn_func
-    USE_FLASH_ATTN3 = False
+    flash_attn_available = True
+except ImportError:
+    try:
+        from flash_attn import flash_attn_func
+        USE_FLASH_ATTN3 = False
+        flash_attn_available = True
+    except ImportError:
+        USE_FLASH_ATTN3 = False
+        flash_attn_available = False
 # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
 def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
@@ -18,10 +26,28 @@ def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> t
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
 def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
-    if USE_FLASH_ATTN3:
-        hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
+    if flash_attn_available:
+        if USE_FLASH_ATTN3:
+            hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
+        else:
+            hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
     else:
-        hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
+        # Use torch's scaled dot-product attention as fallback
+        # Reshape for torch.nn.functional.scaled_dot_product_attention which expects [batch, heads, seq_len, head_dim]
+        query = query.transpose(1, 2)  # [batch, heads, seq_len, head_dim]
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value,
+            attn_mask=None,
+            dropout_p=0.0,
+            is_causal=False
+        )
+        # Restore original shape
+        hidden_states = hidden_states.transpose(1, 2)  # [batch, seq_len, heads, head_dim]
     hidden_states = hidden_states.flatten(-2)
     hidden_states = hidden_states.to(query.dtype)
     return hidden_states
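The transposes in the fallback exist because flash_attn_func takes tensors as [batch, seq_len, heads, head_dim], while torch.nn.functional.scaled_dot_product_attention expects [batch, heads, seq_len, head_dim]. A standalone sketch of that fallback path, assuming the flash-attn layout and using made-up shapes, shows the flattened [batch, seq_len, heads * head_dim] result callers receive:

import torch
import torch.nn.functional as F

# Standalone sketch of the SDPA fallback above (random data, float32 for CPU).
def sdpa_fallback(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    q = query.transpose(1, 2)  # [batch, seq_len, heads, head_dim] -> [batch, heads, seq_len, head_dim]
    k = key.transpose(1, 2)
    v = value.transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
    out = out.transpose(1, 2)  # back to [batch, seq_len, heads, head_dim]
    return out.flatten(-2).to(query.dtype)

batch, seq_len, heads, head_dim = 2, 128, 16, 64
q = k = v = torch.randn(batch, seq_len, heads, head_dim)
print(sdpa_fallback(q, k, v).shape)  # torch.Size([2, 128, 1024])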