mirror of https://github.com/ostris/ai-toolkit.git
synced 2026-01-26 16:39:47 +00:00
Make flash attn optional. Handle larger batch sizes.
@@ -58,7 +58,8 @@ RUN echo "Cache bust: ${CACHEBUST}" && \
 WORKDIR /app/ai-toolkit
 
 # Install Python dependencies
-RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir -r requirements.txt && \
+    pip install flash-attn --no-build-isolation --no-cache-dir
 
 # Build UI
 WORKDIR /app/ai-toolkit/ui
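A note on the Dockerfile change above: flash-attn is compiled against the torch that is already installed by the requirements step, which is why it is added in a separate pip invocation with --no-build-isolation rather than folded into the same install. A quick way to confirm the optional wheel actually built in the resulting environment is a probe like this (a sketch; run it inside the container or virtualenv in question):

    # Probe whether flash-attn is importable; the code below falls back to SDPA when it is not
    try:
        import flash_attn
        print("flash-attn available:", flash_attn.__version__)
    except ImportError:
        print("flash-attn not installed; PyTorch scaled_dot_product_attention will be used")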
@@ -129,14 +129,6 @@ class HidreamModel(BaseModel):
             torch_dtype=torch.bfloat16
         )
 
-        # count the params
-        sd = transformer.state_dict()
-        num_params = sum(p.numel() for p in sd.values())
-        print(f"Number of params in transformer: {num_params}")
-        # count params with name expert in them
-        num_expert_params = sum(p.numel() for k, p in sd.items() if 'expert' in k)
-        print(f"Number of params in transformer with expert in them: {num_expert_params}")
-
         if not self.low_vram:
             transformer.to(self.device_torch, dtype=dtype)
 
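The block removed above only printed parameter counts for the loaded transformer. If that diagnostic is ever wanted again, it can be reproduced ad hoc without keeping debug code in the load path; a minimal sketch, assuming a loaded nn.Module named transformer:

    # Count all parameters, and those whose names contain 'expert'
    num_params = sum(p.numel() for p in transformer.parameters())
    num_expert_params = sum(p.numel() for n, p in transformer.named_parameters() if 'expert' in n)
    print(f"total: {num_params}, expert: {num_expert_params}")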
@@ -2,12 +2,20 @@ from typing import Optional
 import torch
 from .attention import HiDreamAttention
 
+# Try to import Flash Attention first
+flash_attn_available = False
 try:
     from flash_attn_interface import flash_attn_func
     USE_FLASH_ATTN3 = True
-except:
-    from flash_attn import flash_attn_func
-    USE_FLASH_ATTN3 = False
+    flash_attn_available = True
+except ImportError:
+    try:
+        from flash_attn import flash_attn_func
+        USE_FLASH_ATTN3 = False
+        flash_attn_available = True
+    except ImportError:
+        USE_FLASH_ATTN3 = False
+        flash_attn_available = False
 
 # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
 def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
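The nested try/except above is what makes flash attention optional: USE_FLASH_ATTN3 records whether the flash-attn 3 interface (flash_attn_interface) or the flash-attn 2 package (flash_attn) was found, and flash_attn_available gates the fallback in the attention function below. An equivalent probe that only tests availability, without importing the extension at module load time, could use importlib; a sketch, not what the commit does:

    import importlib.util

    # True if either the flash-attn 3 interface or the flash-attn 2 package is installed
    flash_attn_available = any(
        importlib.util.find_spec(name) is not None
        for name in ("flash_attn_interface", "flash_attn")
    )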
@@ -18,10 +26,28 @@ def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> t
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
 
 def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
-    if USE_FLASH_ATTN3:
-        hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
+    if flash_attn_available:
+        if USE_FLASH_ATTN3:
+            hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
+        else:
+            hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
     else:
-        hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
+        # Use torch's scaled dot-product attention as fallback
+        # Reshape for torch.nn.functional.scaled_dot_product_attention which expects [batch, heads, seq_len, head_dim]
+        query = query.transpose(1, 2)  # [batch, heads, seq_len, head_dim]
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value,
+            attn_mask=None,
+            dropout_p=0.0,
+            is_causal=False
+        )
+
+        # Restore original shape
+        hidden_states = hidden_states.transpose(1, 2)  # [batch, seq_len, heads, head_dim]
+
     hidden_states = hidden_states.flatten(-2)
     hidden_states = hidden_states.to(query.dtype)
     return hidden_states
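For clarity on the fallback: flash_attn_func consumes tensors laid out as [batch, seq_len, heads, head_dim], while torch.nn.functional.scaled_dot_product_attention expects [batch, heads, seq_len, head_dim], which is why the new branch transposes on the way in and back out before flattening the head dimensions. A self-contained sketch of that SDPA path, with arbitrary illustrative shapes:

    import torch
    import torch.nn.functional as F

    # Query/key/value in the flash-attn layout: [batch, seq_len, heads, head_dim]
    q = torch.randn(2, 16, 8, 64)
    k = torch.randn(2, 16, 8, 64)
    v = torch.randn(2, 16, 8, 64)

    # SDPA wants [batch, heads, seq_len, head_dim], so swap the seq and head axes
    out = F.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=None, dropout_p=0.0, is_causal=False,
    )

    # Swap back and merge heads * head_dim, matching what the flash-attn branches return
    out = out.transpose(1, 2).flatten(-2)
    print(out.shape)  # torch.Size([2, 16, 512])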