diff --git a/docker/Dockerfile b/docker/Dockerfile
index 0d95b254..020f4c1d 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -58,7 +58,8 @@ RUN echo "Cache bust: ${CACHEBUST}" && \
 WORKDIR /app/ai-toolkit
 
 # Install Python dependencies
-RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir -r requirements.txt && \
+    pip install flash-attn --no-build-isolation --no-cache-dir
 
 # Build UI
 WORKDIR /app/ai-toolkit/ui
diff --git a/extensions_built_in/diffusion_models/hidream/hidream_model.py b/extensions_built_in/diffusion_models/hidream/hidream_model.py
index eb02c644..3322a111 100644
--- a/extensions_built_in/diffusion_models/hidream/hidream_model.py
+++ b/extensions_built_in/diffusion_models/hidream/hidream_model.py
@@ -129,14 +129,6 @@ class HidreamModel(BaseModel):
             torch_dtype=torch.bfloat16
         )
 
-        # count the params
-        sd = transformer.state_dict()
-        num_params = sum(p.numel() for p in sd.values())
-        print(f"Number of params in transformer: {num_params}")
-        # count params with name expert in them
-        num_expert_params = sum(p.numel() for k, p in sd.items() if 'expert' in k)
-        print(f"Number of params in transformer with expert in them: {num_expert_params}")
-
         if not self.low_vram:
             transformer.to(self.device_torch, dtype=dtype)
 
diff --git a/extensions_built_in/diffusion_models/hidream/src/models/attention_processor.py b/extensions_built_in/diffusion_models/hidream/src/models/attention_processor.py
index a989896c..abc2ff40 100644
--- a/extensions_built_in/diffusion_models/hidream/src/models/attention_processor.py
+++ b/extensions_built_in/diffusion_models/hidream/src/models/attention_processor.py
@@ -2,12 +2,20 @@ from typing import Optional
 import torch
 from .attention import HiDreamAttention
 
+# Try to import Flash Attention first
+flash_attn_available = False
 try:
     from flash_attn_interface import flash_attn_func
     USE_FLASH_ATTN3 = True
-except:
-    from flash_attn import flash_attn_func
-    USE_FLASH_ATTN3 = False
+    flash_attn_available = True
+except ImportError:
+    try:
+        from flash_attn import flash_attn_func
+        USE_FLASH_ATTN3 = False
+        flash_attn_available = True
+    except ImportError:
+        USE_FLASH_ATTN3 = False
+        flash_attn_available = False
 
 # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
 def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
@@ -18,10 +26,28 @@ def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor) -> t
     return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
 
 def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
-    if USE_FLASH_ATTN3:
-        hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
+    if flash_attn_available:
+        if USE_FLASH_ATTN3:
+            hidden_states = flash_attn_func(query, key, value, causal=False, deterministic=False)[0]
+        else:
+            hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
     else:
-        hidden_states = flash_attn_func(query, key, value, dropout_p=0., causal=False)
+        # Use torch's scaled dot-product attention as fallback
+        # Reshape for torch.nn.functional.scaled_dot_product_attention which expects [batch, heads, seq_len, head_dim]
+        query = query.transpose(1, 2)  # [batch, heads, seq_len, head_dim]
+        key = key.transpose(1, 2)
+        value = value.transpose(1, 2)
+
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value,
+            attn_mask=None,
+            dropout_p=0.0,
+            is_causal=False
+        )
+
+        # Restore original shape
+        hidden_states = hidden_states.transpose(1, 2)  # [batch, seq_len, heads, head_dim]
+        hidden_states = hidden_states.flatten(-2)
 
     hidden_states = hidden_states.to(query.dtype)
     return hidden_states
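For reference, a minimal standalone sketch (not part of the patch, tensor sizes made up) of the layout handling the SDPA fallback relies on: flash_attn_func operates on [batch, seq_len, heads, head_dim] tensors, while torch.nn.functional.scaled_dot_product_attention expects [batch, heads, seq_len, head_dim], hence the transposes around the call.

```python
import torch
import torch.nn.functional as F

# Made-up sizes for illustration only.
batch, seq_len, heads, head_dim = 2, 16, 8, 64

# flash_attn-style layout: [batch, seq_len, heads, head_dim]
q = torch.randn(batch, seq_len, heads, head_dim)
k = torch.randn(batch, seq_len, heads, head_dim)
v = torch.randn(batch, seq_len, heads, head_dim)

# SDPA expects [batch, heads, seq_len, head_dim], so transpose in...
out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
    attn_mask=None, dropout_p=0.0, is_causal=False,
)

# ...and transpose back to the flash_attn layout afterwards.
out = out.transpose(1, 2)
print(out.shape)  # torch.Size([2, 16, 8, 64])
```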