Files
composable_kernel/dispatcher/python/fmha_utils.py
Vidyasagar Ananthan 86591de476 [rocm-libraries] ROCm/rocm-libraries#5260 (commit a1834d2)
[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher
 (#5260)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The CK Tile dispatcher currently supports GEMM and Grouped Convolution
but has no support for Fused Multi-Head Attention (FMHA). The
example/ck_tile/01_fmha folder contains a comprehensive FMHA
implementation with forward, backward, split-KV, paged-KV, append-KV,
and batch-prefill kernels across multiple GPU architectures — but there
is no unified dispatch layer for it. This PR ports the FMHA stack into
the dispatcher, following the same architectural patterns established by
GEMM and Grouped Convolution, enabling runtime kernel selection, JIT
compilation from Python, and a declarative C++ example flow. Autotuning
heuristics to follow.

## Technical Details

This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring
GEMM's layered architecture. Seven new C++ runtime headers provide type
definitions (coexisting with upstream headers via __has_include,
requiring zero modifications to example/ck_tile/01_fmha/), a problem
builder with 18+ setters, Signature + Algorithm kernel key matching, a
virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard
support and named tile/wave/warp setters, arch-aware registry with JSON
export, and a dispatcher with seqtune-aware selection, configurable
timing, and multi-stage execution plans for split-KV (two-stage) and
backward (three-stage). The codegen pipeline is driven by a
fmha_arch_specs.json capturing per-arch tile tables and pipeline
constraints for five architectures (gfx90a/942/950/1100/1201), migrated
from hardcoded logic in 01_fmha/codegen/, with supporting modules for
C++ symbol mappings, validation rules, and named receipt profiles
(ck_default, flash, pytorch, aiter, fp32, fp8). Python integration
(fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel
multi-kernel builds, HIP memory management via ctypes, tolerance-based
validation, and a NumPy CPU reference with GQA support. Twenty-seven C++
and thirty-two Python examples cover the full feature surface — forward,
split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill,
fp8, logits soft cap, sink tokens, and parameter sweeps — all
JIT-compiled on the fly.

## Test Plan

Seven test files cover the runtime types, codegen, and end-to-end
correctness. C++ unit tests validate the problem builder, dispatcher
planning (single-stage for forward/paged-KV/append-KV; multi-stage for
split-KV and backward), registry operations, and the kernel-set
declaration macro. Python unit tests verify codegen emission, profile
filtering, and 15 validation rules for masks, hdim constraints, and
pipeline requirements. GPU execution validation in 01_basic_fmha
--validate reports zero errors across 65,536 elements with max absolute
error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py)
runs 14 configurations through both the upstream tile_example_fmha_fwd
and the dispatcher, comparing exit codes to confirm behavioral parity —
all 14 match.

## Test Result

The C++ smoke test builds and passes all 9 compiled examples, and a
Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching
up to 375 TFLOPS at seqlen 2048.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-17 07:30:33 +00:00

1843 lines
60 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
FMHA Dispatcher Python Utilities
Provides Python wrappers for FMHA dispatcher kernels via ctypes.
Mirrors ctypes_utils.py (GEMM) and grouped_conv_utils.py (Conv).
Usage:
from fmha_utils import FmhaDispatcherLib, FmhaRunner, FmhaProblem, cpu_attention_fwd
runner = FmhaRunner.from_prebuilt()
result = runner.run(Q, K, V, problem)
"""
import ctypes
import json
import os
import subprocess
import sys
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
# =============================================================================
# Utility helpers
# =============================================================================
try:
    from dispatcher_common import detect_gpu_arch, get_dispatcher_root
except ImportError:
    # dispatcher_common is not importable (standalone usage), so define local
    # stand-ins under the same names.

    def get_dispatcher_root() -> Path:
        """Return the directory two levels above this file."""
        this_file = Path(__file__)
        return this_file.parent.parent

    def detect_gpu_arch(fallback: str = "gfx950") -> str:
        """Best-effort GPU architecture detection via ``rocminfo``.

        Returns the last whitespace-separated token of the first output line
        that contains both "Name:" and "gfx".  If ``rocminfo`` cannot be run,
        fails, or prints no such line, ``fallback`` is returned instead.
        """
        try:
            listing = subprocess.check_output(
                ["rocminfo"], text=True, stderr=subprocess.DEVNULL
            )
            candidates = (
                row.split()[-1].strip()
                for row in listing.splitlines()
                if "Name:" in row and "gfx" in row
            )
            return next(candidates, fallback)
        except Exception:
            # Detection is deliberately best-effort: any failure means "unknown".
            return fallback
# =============================================================================
# Data types
# =============================================================================
@dataclass
class FmhaResult:
    """Outcome of one runner invocation.

    Only ``success`` is required; the remaining fields keep their defaults
    when a run fails before producing data.
    """

    # True when the call completed; False when it was reported as failed.
    success: bool
    # Result payload produced by the run (None when nothing was produced).
    output: Optional[np.ndarray] = None
    # Elapsed time in milliseconds as reported by the dispatcher.
    time_ms: float = 0.0
    # Throughput derived from the operation count and time_ms.
    tflops: float = 0.0
    # Human-readable failure description; empty on success.
    error: str = ""
@dataclass
class FmhaProblem:
    """Problem dimensions for one attention call.

    The shape helpers all use the layout (batch, heads, sequence, head_dim).
    """

    batch: int = 2
    nhead_q: int = 8
    nhead_k: int = 8
    seqlen_q: int = 128
    seqlen_k: int = 128
    hdim_q: int = 128
    hdim_v: int = 128

    @property
    def scale(self) -> float:
        """Softmax scale: the reciprocal of sqrt(hdim_q)."""
        root = self.hdim_q**0.5
        return 1.0 / root

    @property
    def num_ops(self) -> int:
        """Operation count: 2 * batch * nhead_q * seqlen_q * seqlen_k * (hdim_q + hdim_v)."""
        score_elems = self.batch * self.nhead_q * self.seqlen_q * self.seqlen_k
        return 2 * score_elems * (self.hdim_q + self.hdim_v)

    def q_shape(self):
        """Shape of Q: (batch, nhead_q, seqlen_q, hdim_q)."""
        return (self.batch, self.nhead_q, self.seqlen_q, self.hdim_q)

    def k_shape(self):
        """Shape of K: (batch, nhead_k, seqlen_k, hdim_q)."""
        return (self.batch, self.nhead_k, self.seqlen_k, self.hdim_q)

    def v_shape(self):
        """Shape of V: (batch, nhead_k, seqlen_k, hdim_v)."""
        return (self.batch, self.nhead_k, self.seqlen_k, self.hdim_v)

    def o_shape(self):
        """Shape of the output: (batch, nhead_q, seqlen_q, hdim_v)."""
        return (self.batch, self.nhead_q, self.seqlen_q, self.hdim_v)
@dataclass
class FmhaKernelConfig:
    """Complete kernel configuration for FMHA.

    All tile/wave/warp dimensions are explicitly named to match the
    GEMM pattern (tile_m, tile_n, tile_k) but extended for FMHA's
    two-stage computation (Q*K^T stage 0, Attn*V stage 1).

    The grouped views (``tile``, ``wave``, ``warp``, ``padding``) pack the
    individual fields into tuples, ``name`` renders a kernel name string, and
    ``to_codegen_json`` serializes the configuration as JSON.
    """

    # -- Signature: what operation --
    family: str = "fwd"
    data_type: str = "fp16"
    mode: str = "batch"
    vlayout: str = "r"
    hdim_q: int = 128
    hdim_v: int = 128
    gfx_arch: str = "gfx950"
    # -- Algorithm: tile shape --
    # Stage 0 (Q * K^T): seqlen_q x seqlen_k x hdim_q
    tile_m0: int = 128  # seqlen_q tile
    tile_n0: int = 128  # seqlen_k tile
    tile_k0: int = 32  # hdim_q tile
    # Stage 1 (Attn * V): seqlen_q x hdim_v x seqlen_k
    tile_n1: int = 128  # hdim_v tile
    tile_k1: int = 32  # seqlen_k tile
    tile_k0max: int = 128  # max k0 (alignment)
    # BWD extra stages (9-element tile)
    tile_bwd6: int = 0
    tile_bwd7: int = 0
    tile_bwd8: int = 0
    # -- Algorithm: wave config (warps per block) --
    wave_m0: int = 4
    wave_n0: int = 1
    wave_k0: int = 1
    wave_m1: int = 4
    wave_n1: int = 1
    wave_k1: int = 1
    wave_m2: int = 1
    wave_n2: int = 1
    wave_k2: int = 1
    # -- Algorithm: warp tile (elements per warp) --
    warp_m0: int = 32
    warp_n0: int = 32
    warp_k0: int = 16
    warp_m1: int = 32
    warp_n1: int = 32
    warp_k1: int = 16
    warp_m2: int = 16
    warp_n2: int = 16
    warp_k2: int = 16
    # -- Algorithm: padding --
    # Values: 0=no pad, 1=pad, 8=pad with 8-byte alignment (BWD-specific)
    pad_s: int = 1
    pad_sk: int = 1
    pad_d: int = 1
    pad_dv: int = 1
    # -- Algorithm: pipeline --
    pipeline: str = "qr_async"
    block_per_cu: int = -1
    num_wave_groups: int = 1
    # -- Signature: features --
    mask: str = "no"
    bias: str = "no"
    lse: bool = False
    dropout: bool = False
    qscale: str = "no"
    rope: str = "none"
    logits: bool = False
    paged_kv: bool = False
    sink: bool = False
    skip_min_seqlen_q: bool = False
    page_size: int = 1
    kv_memory_layout: str = "vectorized"
    kv_lookup_table: str = "sglang"
    deterministic: bool = False
    dbias: bool = False
    dropout_variant: str = ""  # BWD: "no"/"dropout_wg16"/"dropout_wg16_storerandval"
    tile_tag: str = ""  # extra tile variant discriminator (e.g. "trload", "small")
    use_trload: bool = False  # BWD dq_dk_dv: use trload pipeline path

    @property
    def tile(self) -> Tuple[int, ...]:
        """Tile shape: six entries, or nine for ``bwd_dq_dk_dv`` when ``tile_bwd6 > 0``."""
        base = (
            self.tile_m0,
            self.tile_n0,
            self.tile_k0,
            self.tile_n1,
            self.tile_k1,
            self.tile_k0max,
        )
        if self.family == "bwd_dq_dk_dv" and self.tile_bwd6 > 0:
            return base + (self.tile_bwd6, self.tile_bwd7, self.tile_bwd8)
        return base

    @property
    def wave(self) -> Tuple[int, ...]:
        """Wave configuration as a flat 9-tuple (m, n, k for stages 0, 1, 2)."""
        return (
            self.wave_m0,
            self.wave_n0,
            self.wave_k0,
            self.wave_m1,
            self.wave_n1,
            self.wave_k1,
            self.wave_m2,
            self.wave_n2,
            self.wave_k2,
        )

    @property
    def warp(self) -> Tuple[int, ...]:
        """Warp tile as a flat 9-tuple (m, n, k for stages 0, 1, 2)."""
        return (
            self.warp_m0,
            self.warp_n0,
            self.warp_k0,
            self.warp_m1,
            self.warp_n1,
            self.warp_k1,
            self.warp_m2,
            self.warp_n2,
            self.warp_k2,
        )

    @property
    def padding(self) -> Tuple[int, ...]:
        """Padding settings (pad_s, pad_sk, pad_d, pad_dv).

        The entries are ints, not bools: besides 0 and 1 the fields may hold 8.
        """
        return (self.pad_s, self.pad_sk, self.pad_d, self.pad_dv)

    @property
    def name(self) -> str:
        """Kernel name: mandatory parts, then tags for non-default features, then ``bpc``.

        Only the first six tile entries and the stage-0 warp tile appear in the
        name; feature tags are appended solely when a feature differs from its
        default so the common case stays short.
        """
        if self.hdim_q != self.hdim_v:
            hdim_part = f"h{self.hdim_q}x{self.hdim_v}"
        else:
            hdim_part = f"h{self.hdim_q}"

        tile_part = (
            f"t{self.tile_m0}x{self.tile_n0}x{self.tile_k0}"
            f"x{self.tile_n1}x{self.tile_k1}x{self.tile_k0max}"
        )
        if self.tile_tag:
            tile_part += f".{self.tile_tag}"

        parts = [
            f"fmha_{self.family}_{self.data_type}",
            self.mode,
            hdim_part,
            self.pipeline,
            tile_part,
            # Always include warp class for uniform naming
            f"w{self.warp_m0}x{self.warp_n0}x{self.warp_k0}",
            f"pad{self.pad_s}{self.pad_sk}{self.pad_d}{self.pad_dv}",
            f"mask={self.mask}",
            f"bias={self.bias}",
        ]

        # (enabled, tag) pairs; the order here is the order in the name.
        optional_tags = [
            (self.lse, "lse=1"),
            (self.dropout, "drop=1"),
            (self.logits, "logits=1"),
            (self.sink, "sink=1"),
            (self.skip_min_seqlen_q, "skip=1"),
            (self.qscale != "no", f"qs={self.qscale}"),
            (self.paged_kv, "pkv=1"),
            (self.rope != "none", f"rope={self.rope}"),
            (self.page_size != 1, f"ps={self.page_size}"),
            (self.kv_memory_layout != "vectorized", f"kvl={self.kv_memory_layout}"),
            (self.kv_lookup_table != "sglang", f"kvt={self.kv_lookup_table}"),
            (self.deterministic, "det=1"),
            (self.dbias, "dbias=1"),
            (
                self.dropout_variant and self.dropout_variant != "no",
                f"drv={self.dropout_variant}",
            ),
        ]
        parts.extend(tag for enabled, tag in optional_tags if enabled)

        # Always include block_per_cu for uniform naming
        parts.append(f"bpc={self.block_per_cu}")
        return "_".join(parts)

    def to_codegen_json(self) -> str:
        """Serialize to a JSON string with "arch", "signature" and "algorithm" sections.

        ``fp8_static_quant``, ``max_splits_log2`` and ``max_seq_len_q`` are
        emitted as fixed values (False, 0, 0); ``store_randval`` is derived from
        whether ``dropout_variant`` contains "storerandval".
        """
        signature = {
            "family": self.family,
            "data_type": self.data_type,
            "mode": self.mode,
            "vlayout": self.vlayout,
            "hdim_q": self.hdim_q,
            "hdim_v": self.hdim_v,
            "mask": self.mask,
            "bias": self.bias,
            "lse": self.lse,
            "dropout": self.dropout,
            "qscale": self.qscale,
            "rope": self.rope,
            "logits": self.logits,
            "paged_kv": self.paged_kv,
            "fp8_static_quant": False,
            "skip_min_seqlen_q": self.skip_min_seqlen_q,
            "sink": self.sink,
            "dbias": self.dbias,
            "store_randval": "storerandval" in self.dropout_variant,
            "deterministic": self.deterministic,
            "dropout_variant": self.dropout_variant,
            "kv_memory_layout": self.kv_memory_layout,
            "kv_lookup_table": self.kv_lookup_table,
            "page_size": self.page_size,
        }
        algorithm = {
            "pipeline": self.pipeline,
            "tile": list(self.tile),
            "wave": list(self.wave),
            "warp": list(self.warp),
            "padding": list(self.padding),
            "block_per_cu": self.block_per_cu,
            "num_wave_groups": self.num_wave_groups,
            "max_splits_log2": 0,
            "max_seq_len_q": 0,
            "use_trload": self.use_trload,
        }
        return json.dumps(
            {"arch": self.gfx_arch, "signature": signature, "algorithm": algorithm}
        )
# =============================================================================
# CPU reference
# =============================================================================
def _float32_to_bf16(arr: np.ndarray) -> np.ndarray:
    """Convert an array to bf16 bit patterns held in uint16.

    The input is cast to float32 and the upper 16 bits of each value are kept;
    the lower 16 bits are discarded (truncation, no rounding).
    """
    as_bits = arr.astype(np.float32).view(np.uint32)
    upper_half = as_bits >> 16
    return upper_half.astype(np.uint16)
def _bf16_to_float32(arr: np.ndarray) -> np.ndarray:
    """Expand bf16 bit patterns (uint16) into float32 values.

    Each value is widened to 32 bits and shifted into the upper half, so the
    lower 16 bits of the resulting float32 are zero.
    """
    widened = arr.astype(np.uint32)
    shifted = widened << 16
    return shifted.view(np.float32)
def cpu_attention_fwd(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    mask_type: int = 0,
) -> np.ndarray:
    """CPU reference: scaled dot-product attention (supports GQA and causal mask).

    Rows whose keys are all masked out (possible with the bottom-right causal
    mask when seqlen_q > seqlen_k) produce an all-zero output row instead of a
    uniform average over the masked keys.

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q] float32
        K: [batch, nhead_k, seqlen_k, hdim_q] float32
        V: [batch, nhead_k, seqlen_k, hdim_v] float32
        scale: multiplier applied to Q @ K^T before the softmax
        mask_type: 0=no mask, 1=causal top-left, 2=causal bottom-right
            (any other value is treated as no mask)

    Returns:
        O: [batch, nhead_q, seqlen_q, hdim_v] float32
    """
    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    if nhead_q != nhead_k:
        # GQA: each K/V head serves nhead_q // nhead_k consecutive Q heads.
        ratio = nhead_q // nhead_k
        K = np.repeat(K, ratio, axis=1)
        V = np.repeat(V, ratio, axis=1)

    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale

    if mask_type in (1, 2):
        sq, sk = S.shape[-2], S.shape[-1]
        row = np.arange(sq).reshape(sq, 1)
        col = np.arange(sk).reshape(1, sk)
        # Top-left aligns the diagonal at (0, 0); bottom-right aligns it at
        # the last row / last column, i.e. shifts it by (sk - sq).
        diagonal_shift = 0 if mask_type == 1 else sk - sq
        visible = col <= (row + diagonal_shift)
        # -inf (rather than a large finite value) lets fully masked rows be
        # recognised below instead of collapsing into a uniform softmax.
        S = np.where(visible, S, -np.inf)

    S_max = S.max(axis=-1, keepdims=True)
    # A fully masked row has max == -inf; subtracting it would give nan, so
    # use 0 there and let exp(-inf) evaluate to 0 for every key.
    S_max[np.isneginf(S_max)] = 0.0
    S_exp = np.exp(S - S_max)
    denom = S_exp.sum(axis=-1, keepdims=True)
    # Avoid 0/0 for fully masked rows; their numerators are all zero already.
    denom[denom == 0] = 1.0
    P = S_exp / denom
    return np.matmul(P, V)
def cpu_attention_fwd_with_intermediates(
    Q: np.ndarray, K: np.ndarray, V: np.ndarray, scale: float
) -> tuple:
    """Unmasked CPU reference forward that also hands back the softmax matrix.

    When the head counts differ, K and V heads are repeated
    nhead_q // nhead_k times along axis 1 before the computation.

    Returns:
        (output, P): the attention output and the softmax probabilities.
    """
    heads_q, heads_k = Q.shape[1], K.shape[1]
    if heads_q != heads_k:
        group = heads_q // heads_k
        K, V = (np.repeat(tensor, group, axis=1) for tensor in (K, V))

    scores = (Q @ K.transpose(0, 1, 3, 2)) * scale
    # Shift by the row maximum before exponentiating.
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    P = weights / weights.sum(axis=-1, keepdims=True)
    return P @ V, P
def cpu_attention_bwd(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
) -> tuple:
    """CPU reference backward. Returns (dQ, dK, dV).

    Supports GQA: when K/V have fewer heads than Q, they are expanded the same
    way the forward reference expands them (each K/V head repeated
    nhead_q // nhead_k times along axis 1), and the resulting per-query-head
    dK/dV are summed over each head group so that dK and dV have the shapes of
    K and V.

    Args:
        Q, K, V: forward inputs [batch, heads, seq, dim]
        out: forward output
        dO: gradient of output
        P: softmax probabilities from forward
        scale: attention scale factor

    Returns:
        (dQ, dK, dV) with the shapes of Q, K and V respectively.
    """
    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    group = nhead_q // nhead_k if nhead_q != nhead_k else 1
    if group > 1:
        K_full = np.repeat(K, group, axis=1)
        V_full = np.repeat(V, group, axis=1)
    else:
        K_full, V_full = K, V

    # Row-wise sum(dO * out), the softmax-Jacobian correction term.
    D = (dO * out).sum(axis=-1, keepdims=True)
    dP = np.matmul(dO, V_full.transpose(0, 1, 3, 2))
    dS = P * (dP - D)
    dQ = np.matmul(dS, K_full) * scale
    dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) * scale
    dV = np.matmul(P.transpose(0, 1, 3, 2), dO)

    if group > 1:
        # np.repeat lays the copies of one K/V head out consecutively, so
        # folding axis 1 into (nhead_k, group) and summing over `group`
        # accumulates the gradient of every copy back onto its source head.
        dK = dK.reshape(
            dK.shape[0], nhead_k, group, dK.shape[2], dK.shape[3]
        ).sum(axis=2)
        dV = dV.reshape(
            dV.shape[0], nhead_k, group, dV.shape[2], dV.shape[3]
        ).sum(axis=2)
    return dQ, dK, dV
# =============================================================================
# Low-level ctypes wrapper
# =============================================================================
class FmhaDispatcherLib:
    """ctypes binding for the FMHA dispatcher shared library (libdispatcher_fmha_lib.so).

    Construction declares the C signatures of every entry point this module
    calls.  Use ``find()`` to probe the standard build locations or ``load()``
    to open an explicit path.
    """

    # Candidate locations, relative to get_dispatcher_root(), probed in order.
    SEARCH_PATHS = [
        "build/examples/libdispatcher_fmha_lib.so",
        "build/libdispatcher_fmha_lib.so",
        "build/lib/libdispatcher_fmha_lib.so",
    ]

    def __init__(self, lib: ctypes.CDLL, path: Path):
        self._lib = lib
        self.path = path
        self._setup()

    def _setup(self):
        """Attach argtypes/restype to each exported function.

        Signatures are written as runs of identical ctypes types; the comment
        next to each run names the parameters it covers, in order.
        """
        ptr = ctypes.c_void_p
        i32 = ctypes.c_int
        f32 = ctypes.c_float
        cstr = ctypes.c_char_p
        time_out = ctypes.POINTER(ctypes.c_float)

        # batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim_q, hdim_v
        problem_dims = [i32] * 7

        signatures = {
            "fmha_dispatcher_initialize": ([cstr], i32),
            "fmha_dispatcher_run_fwd": (
                [ptr] * 4  # q, k, v, o
                + problem_dims
                + [f32]  # scale
                # mask_type, bias_type, has_lse, has_dropout,
                # traits_hdim_q (0=same as hdim_q), traits_hdim_v (0=same as hdim_v),
                # is_v_rowmajor (1=row, 0=col), perm (1=BHSD, 0=BSHD)
                + [i32] * 8
                + [cstr]  # data_type ("fp16", "bf16")
                # is_group_mode, window_left (-1=no window),
                # window_right (-1=no window, 0=causal),
                # has_logits, has_sink, has_skip
                + [i32] * 6
                + [time_out],  # time_ms_out
                i32,
            ),
            "fmha_dispatcher_run_bwd": (
                [ptr] * 9  # q, k, v, o, lse, do, dq, dk, dv
                + problem_dims
                + [f32]  # scale
                + [cstr]  # data_type_str
                # mask_type_int, bias_type_int, has_dropout, has_dbias,
                # is_deterministic, is_group_mode, is_store_randval,
                # tile_n0 (kN0 for nsplits computation)
                + [i32] * 8
                + [time_out],  # time_ms_out
                i32,
            ),
            # Split-KV forward
            "fmha_dispatcher_run_splitkv": (
                [ptr] * 4
                + [i32] * 7
                + [f32]
                + [i32] * 3  # mask_type, num_splits, is_v_rowmajor
                + [cstr]
                # has_lse, is_group_mode, perm, has_logits, bias_type,
                # has_sink, paged_kv, page_block_size, window_left, window_right
                + [i32] * 10
                + [time_out],
                i32,
            ),
            # Paged-KV forward
            "fmha_dispatcher_run_pagedkv": (
                [ptr] * 4
                + [i32] * 7
                + [f32]
                + [i32] * 3  # mask_type, page_block_size, is_v_rowmajor
                + [cstr]
                # has_lse, has_logits, has_sink, skip_min_seqlen_q, bias_type
                + [i32] * 5
                + [time_out],
                i32,
            ),
            # Append-KV
            "fmha_dispatcher_run_appendkv": (
                [ptr] * 3
                + [i32] * 7
                # is_v_rowmajor, rope_type, paged_kv, page_block_size
                + [i32] * 4
                + [cstr]
                + [time_out],
                i32,
            ),
            # Batch Prefill
            "fmha_dispatcher_run_batch_prefill": (
                [ptr] * 4
                + [i32] * 7
                + [f32]
                # mask_type, bias_type, page_block_size, kv_layout_int,
                # kv_lookup_int, is_v_rowmajor
                + [i32] * 6
                + [cstr]
                # has_lse, has_dropout, has_logits, has_sink, skip_min_seqlen_q
                + [i32] * 5
                + [time_out],
                i32,
            ),
            "fmha_dispatcher_kernel_count": ([], i32),
            "fmha_dispatcher_cleanup": ([], None),
        }

        lib = self._lib
        for symbol, (argtypes, restype) in signatures.items():
            entry = getattr(lib, symbol)
            entry.argtypes = argtypes
            entry.restype = restype

    @classmethod
    def find(cls) -> Optional["FmhaDispatcherLib"]:
        """Return a wrapper for the first loadable library in SEARCH_PATHS, else None."""
        base = get_dispatcher_root()
        for relative in cls.SEARCH_PATHS:
            candidate = base / relative
            if not candidate.exists():
                continue
            try:
                handle = ctypes.CDLL(str(candidate))
                return cls(handle, candidate)
            except OSError:
                # Present but not loadable; try the next location.
                continue
        return None

    @classmethod
    def load(cls, path: str) -> "FmhaDispatcherLib":
        """Open the library at ``path``; loader errors propagate to the caller."""
        handle = ctypes.CDLL(path)
        return cls(handle, Path(path))

    def initialize(self, arch: str = "gfx950") -> bool:
        """Call ``fmha_dispatcher_initialize``; True when it returns 0."""
        status = self._lib.fmha_dispatcher_initialize(arch.encode())
        return status == 0

    def run_bwd(
        self,
        q: ctypes.c_void_p,
        k: ctypes.c_void_p,
        v: ctypes.c_void_p,
        o: ctypes.c_void_p,
        lse: ctypes.c_void_p,
        do_grad: ctypes.c_void_p,
        dq: ctypes.c_void_p,
        dk: ctypes.c_void_p,
        dv: ctypes.c_void_p,
        prob: FmhaProblem,
        data_type: str = "fp16",
        mask_type: int = 0,
        bias_type: int = 0,
        has_dropout: bool = False,
        has_dbias: bool = False,
        is_deterministic: bool = False,
        is_group_mode: bool = False,
        is_store_randval: bool = False,
        tile_n0: int = 128,
    ) -> Tuple[int, float]:
        """Invoke ``fmha_dispatcher_run_bwd`` on caller-provided buffers.

        Dimensions and the scale are read from ``prob``; boolean options are
        passed as C ints.

        Returns:
            (return code of the C call, elapsed milliseconds it reported)
        """
        elapsed = ctypes.c_float(0.0)
        buffers = (q, k, v, o, lse, do_grad, dq, dk, dv)
        dims = (
            prob.batch,
            prob.nhead_q,
            prob.nhead_k,
            prob.seqlen_q,
            prob.seqlen_k,
            prob.hdim_q,
            prob.hdim_v,
        )
        switches = tuple(
            ctypes.c_int(int(flag))
            for flag in (
                has_dropout,
                has_dbias,
                is_deterministic,
                is_group_mode,
                is_store_randval,
            )
        )
        status = self._lib.fmha_dispatcher_run_bwd(
            *buffers,
            *dims,
            prob.scale,
            data_type.encode(),
            ctypes.c_int(mask_type),
            ctypes.c_int(bias_type),
            *switches,
            ctypes.c_int(tile_n0),
            ctypes.byref(elapsed),
        )
        return status, elapsed.value

    def kernel_count(self) -> int:
        """Return the value of ``fmha_dispatcher_kernel_count``."""
        return self._lib.fmha_dispatcher_kernel_count()

    def cleanup(self):
        """Call ``fmha_dispatcher_cleanup``."""
        self._lib.fmha_dispatcher_cleanup()
# =============================================================================
# High-level GPU runner (mirrors GpuGroupedConvRunner)
# =============================================================================
class FmhaRunner:
    """High-level FMHA runner with NumPy interface and HIP memory management.

    Each ``run`` / ``run_bwd`` call allocates device buffers through the HIP
    runtime (loaded with ctypes), copies the host arrays in, calls the
    dispatcher entry point and frees the buffers in a ``finally`` block.

    HIP status codes are checked: a failing hipMalloc / hipMemcpy / hipMemset
    is reported as ``FmhaResult(success=False, error=...)``, the same way a
    non-zero kernel return code is reported.
    """

    # Copy-direction values passed as the last argument of hipMemcpy.
    HIP_MEMCPY_H2D = 1
    HIP_MEMCPY_D2H = 2

    # Host buffer dtype per CK data_type string (unknown strings use float16).
    # bf16 is stored as uint16 (upper 16 bits of float32).
    # fp8 uses uint8 (1 byte per element).
    _NP_IN_DTYPE = {
        "fp16": np.float16,
        "bf16": np.uint16,
        "fp32": np.float32,
        "fp8bf16": np.uint8,
        "fp8fp32": np.uint8,
        "bf8": np.uint8,
    }
    # Host dtype of the forward output buffer per CK data_type string.
    _NP_OUT_DTYPE = {
        "fp16": np.float16,
        "bf16": np.uint16,
        "fp32": np.float32,
        "fp8bf16": np.float16,
        "fp8fp32": np.float32,
        "bf8": np.uint8,
    }

    class _HipError(RuntimeError):
        """Internal signal: a HIP runtime call returned a non-zero status."""

    def __init__(self, dispatch_lib: FmhaDispatcherLib, arch: str = "gfx950"):
        """Load the HIP runtime and initialize the dispatcher for ``arch``.

        Raises:
            RuntimeError: if the HIP runtime cannot be loaded or the
                dispatcher fails to initialize.
        """
        self._lib = dispatch_lib
        self._arch = arch
        self._hip = None
        self._load_hip()
        if not dispatch_lib.initialize(arch):
            raise RuntimeError("Failed to initialize FMHA dispatcher")

    def _load_hip(self):
        """Open libamdhip64 and declare the signatures of the calls used here."""
        for name in ["libamdhip64.so", "libamdhip64.so.6"]:
            try:
                self._hip = ctypes.CDLL(name)
                self._hip.hipMalloc.argtypes = [
                    ctypes.POINTER(ctypes.c_void_p),
                    ctypes.c_size_t,
                ]
                self._hip.hipMalloc.restype = ctypes.c_int
                self._hip.hipFree.argtypes = [ctypes.c_void_p]
                self._hip.hipFree.restype = ctypes.c_int
                self._hip.hipMemcpy.argtypes = [
                    ctypes.c_void_p,
                    ctypes.c_void_p,
                    ctypes.c_size_t,
                    ctypes.c_int,
                ]
                self._hip.hipMemcpy.restype = ctypes.c_int
                self._hip.hipMemset.argtypes = [
                    ctypes.c_void_p,
                    ctypes.c_int,
                    ctypes.c_size_t,
                ]
                self._hip.hipMemset.restype = ctypes.c_int
                return
            except OSError:
                continue
        raise RuntimeError("Could not load libamdhip64.so")

    @classmethod
    def from_prebuilt(cls, arch: Optional[str] = None) -> "FmhaRunner":
        """Create a runner from the first library found by ``FmhaDispatcherLib.find()``.

        Raises:
            RuntimeError: if no library is found.
        """
        arch = arch or detect_gpu_arch()
        lib = FmhaDispatcherLib.find()
        if lib is None:
            raise RuntimeError(
                "FMHA dispatcher library not found. Build with:\n"
                "  cd dispatcher/build && cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON && make dispatcher_fmha_lib"
            )
        return cls(lib, arch)

    @classmethod
    def from_library(cls, path: str, arch: Optional[str] = None) -> "FmhaRunner":
        """Create a runner from the shared library at ``path``."""
        arch = arch or detect_gpu_arch()
        return cls(FmhaDispatcherLib.load(path), arch)

    # -- internal helpers -----------------------------------------------------

    def _hip_check(self, status: int, what: str) -> None:
        """Raise ``_HipError`` when a HIP call reports a non-zero status."""
        if status != 0:
            raise self._HipError(f"{what} failed (hipError={status})")

    def _host_buffer(self, arr: np.ndarray, data_type: str) -> np.ndarray:
        """Return a contiguous host copy of ``arr`` in the layout for ``data_type``.

        bf16 goes through float32 and is truncated to uint16 bit patterns;
        every other type is a plain ``astype`` to the mapped NumPy dtype.
        """
        if data_type == "bf16":
            return _float32_to_bf16(np.ascontiguousarray(arr.astype(np.float32)))
        host_dt = self._NP_IN_DTYPE.get(data_type, np.float16)
        return np.ascontiguousarray(arr.astype(host_dt))

    def _upload(self, dev: ctypes.c_void_p, host: np.ndarray) -> None:
        """Allocate ``dev`` with the size of ``host`` and copy ``host`` into it."""
        self._hip_check(
            self._hip.hipMalloc(ctypes.byref(dev), host.nbytes), "hipMalloc"
        )
        self._hip_check(
            self._hip.hipMemcpy(
                dev, host.ctypes.data, host.nbytes, self.HIP_MEMCPY_H2D
            ),
            "hipMemcpy(H2D)",
        )

    def _alloc_zeroed(self, dev: ctypes.c_void_p, nbytes: int) -> None:
        """Allocate ``nbytes`` for ``dev`` and fill it with zero bytes."""
        self._hip_check(self._hip.hipMalloc(ctypes.byref(dev), nbytes), "hipMalloc")
        self._hip_check(self._hip.hipMemset(dev, 0, nbytes), "hipMemset")

    def _download(self, host: np.ndarray, dev: ctypes.c_void_p) -> None:
        """Copy ``host.nbytes`` bytes from ``dev`` into ``host``."""
        self._hip_check(
            self._hip.hipMemcpy(
                host.ctypes.data, dev, host.nbytes, self.HIP_MEMCPY_D2H
            ),
            "hipMemcpy(D2H)",
        )

    def _free_all(self, pointers) -> None:
        """Free every pointer that was actually allocated (best effort)."""
        for dev in pointers:
            if dev.value:
                self._hip.hipFree(dev)

    # -- public API -------------------------------------------------------------

    def run(
        self,
        Q: np.ndarray,
        K: np.ndarray,
        V: np.ndarray,
        prob: FmhaProblem,
        mask_type: int = 0,
        bias_type: int = 0,
        has_lse: int = 0,
        has_dropout: int = 0,
        has_logits: int = 0,
        has_sink: int = 0,
        has_skip: int = 0,
        api_family: str = "fwd",
        data_type: str = "fp16",
        **kwargs,
    ) -> "FmhaResult":
        """Run FMHA forward on GPU with automatic HIP memory management.

        ``api_family`` selects the entry point: "splitkv", "pagedkv",
        "appendkv" or "batch_prefill"; any other value uses the plain forward
        call.  Optional tuning values are read from ``kwargs``:
        is_v_rowmajor (1), is_group_mode (0), perm (1), window_left (-1),
        window_right (-1), num_splits (4), page_size (64), kv_layout (0),
        kv_lookup (0), traits_hdim_q (0), traits_hdim_v (0), paged_kv (0),
        rope_type (0), seqlen_knew (prob.seqlen_k), skip_min_seqlen_q (0).

        Args:
            Q: [batch, nhead_q, seqlen_q, hdim_q] float16
            K: [batch, nhead_k, seqlen_k, hdim_q] float16
            V: [batch, nhead_k, seqlen_k, hdim_v] float16

        Returns:
            FmhaResult with output array, timing, TFLOPS.  For bf16 the output
            is converted back to float32.  ``success`` is False when a HIP
            call fails or the kernel returns a non-zero code.
        """
        out_dt = self._NP_OUT_DTYPE.get(data_type, np.float16)
        Q_c = self._host_buffer(Q, data_type)
        K_c = self._host_buffer(K, data_type)
        V_c = self._host_buffer(V, data_type)
        O_c = np.zeros(prob.o_shape(), dtype=out_dt)

        d_q, d_k, d_v, d_o = (ctypes.c_void_p() for _ in range(4))
        try:
            for dev, host in ((d_q, Q_c), (d_k, K_c), (d_v, V_c)):
                self._upload(dev, host)
            self._alloc_zeroed(d_o, O_c.nbytes)

            time_ms = ctypes.c_float(0.0)
            lib = self._lib._lib
            is_v_rowmajor = kwargs.get("is_v_rowmajor", 1)
            is_group_mode = kwargs.get("is_group_mode", 0)
            perm = kwargs.get("perm", 1)
            window_left = kwargs.get("window_left", -1)
            window_right = kwargs.get("window_right", -1)
            num_splits = kwargs.get("num_splits", 4)
            page_size = kwargs.get("page_size", 64)
            kv_layout = kwargs.get("kv_layout", 0)
            kv_lookup = kwargs.get("kv_lookup", 0)
            traits_hdim_q = kwargs.get("traits_hdim_q", 0)
            traits_hdim_v = kwargs.get("traits_hdim_v", 0)

            # Leading arguments shared by every entry point except appendkv:
            # four device buffers, seven problem dimensions, then the scale.
            common = (
                d_q,
                d_k,
                d_v,
                d_o,
                prob.batch,
                prob.nhead_q,
                prob.nhead_k,
                prob.seqlen_q,
                prob.seqlen_k,
                prob.hdim_q,
                prob.hdim_v,
                prob.scale,
            )
            dtype_c = data_type.encode()
            time_ref = ctypes.byref(time_ms)

            if api_family == "splitkv":
                paged_kv = kwargs.get("paged_kv", 0)
                rc = lib.fmha_dispatcher_run_splitkv(
                    *common,
                    mask_type,
                    num_splits,
                    is_v_rowmajor,
                    dtype_c,
                    has_lse,
                    is_group_mode,
                    perm,
                    has_logits,
                    bias_type,
                    has_sink,
                    paged_kv,
                    page_size,
                    window_left,
                    window_right,
                    time_ref,
                )
            elif api_family == "pagedkv":
                rc = lib.fmha_dispatcher_run_pagedkv(
                    *common,
                    mask_type,
                    page_size,
                    is_v_rowmajor,
                    dtype_c,
                    has_lse,
                    has_logits,
                    has_sink,
                    has_skip,
                    bias_type,
                    time_ref,
                )
            elif api_family == "appendkv":
                seqlen_knew = kwargs.get("seqlen_knew", prob.seqlen_k)
                # NOTE(review): this path hands the *host* array addresses to
                # the library, unlike every other family which passes the
                # device buffers. Kept as-is; confirm against the C entry point.
                rc = lib.fmha_dispatcher_run_appendkv(
                    Q_c.ctypes.data,
                    K_c.ctypes.data,
                    V_c.ctypes.data,
                    prob.batch,
                    prob.nhead_q,
                    prob.nhead_k,
                    prob.seqlen_q,
                    seqlen_knew,
                    prob.hdim_q,
                    prob.hdim_v,
                    is_v_rowmajor,
                    kwargs.get("rope_type", 0),
                    kwargs.get("paged_kv", 0),
                    page_size,
                    dtype_c,
                    time_ref,
                )
            elif api_family == "batch_prefill":
                skip_min_sq = kwargs.get("skip_min_seqlen_q", 0)
                rc = lib.fmha_dispatcher_run_batch_prefill(
                    *common,
                    mask_type,
                    bias_type,
                    page_size,
                    kv_layout,
                    kv_lookup,
                    is_v_rowmajor,
                    dtype_c,
                    has_lse,
                    has_dropout,
                    has_logits,
                    has_sink,
                    skip_min_sq,
                    time_ref,
                )
            else:
                rc = lib.fmha_dispatcher_run_fwd(
                    *common,
                    mask_type,
                    bias_type,
                    has_lse,
                    has_dropout,
                    traits_hdim_q,
                    traits_hdim_v,
                    is_v_rowmajor,
                    perm,
                    dtype_c,
                    is_group_mode,
                    window_left,
                    window_right,
                    has_logits,
                    has_sink,
                    has_skip,
                    time_ref,
                )

            if rc != 0:
                return FmhaResult(success=False, error=f"Kernel failed (rc={rc})")

            self._download(O_c, d_o)
            # Convert bf16 output (uint16) back to float32 for comparison
            if data_type == "bf16":
                O_c = _bf16_to_float32(O_c)

            # appendkv is a memory op (KV cache copy), not compute -- no TFLOPS
            ops = 0 if api_family == "appendkv" else prob.num_ops
            tflops = (
                ops / (time_ms.value * 1e-3) / 1e12
                if time_ms.value > 0 and ops > 0
                else 0.0
            )
            return FmhaResult(
                success=True, output=O_c, time_ms=time_ms.value, tflops=tflops
            )
        except self._HipError as exc:
            return FmhaResult(success=False, error=str(exc))
        finally:
            self._free_all([d_q, d_k, d_v, d_o])

    def run_bwd(
        self,
        Q: np.ndarray,
        K: np.ndarray,
        V: np.ndarray,
        out: np.ndarray,
        LSE: np.ndarray,
        dO: np.ndarray,
        prob: FmhaProblem,
        data_type: str = "fp16",
        mask_type: int = 0,
        bias_type: int = 0,
        has_dropout: bool = False,
        has_dbias: bool = False,
        is_deterministic: bool = False,
        is_group_mode: bool = False,
        is_store_randval: bool = False,
        tile_n0: int = 128,
    ) -> "FmhaResult":
        """Run FMHA backward on GPU with automatic HIP memory management.

        Q, K, V, ``out`` and ``dO`` are converted to the host layout of
        ``data_type`` (bf16 is uploaded as uint16 bit patterns, matching
        ``run``); ``LSE`` is always uploaded as float32.

        Returns FmhaResult with dQ, dK, dV packed in output as a tuple.  For
        bf16 the gradients are converted back to float32.  ``success`` is False
        when a HIP call fails or the kernel returns a non-zero code.
        """
        Q_c = self._host_buffer(Q, data_type)
        K_c = self._host_buffer(K, data_type)
        V_c = self._host_buffer(V, data_type)
        O_c = self._host_buffer(out, data_type)
        LSE_c = np.ascontiguousarray(LSE.astype(np.float32))
        dO_c = self._host_buffer(dO, data_type)
        dQ_c = np.zeros_like(Q_c)
        dK_c = np.zeros_like(K_c)
        dV_c = np.zeros_like(V_c)

        ptrs = [ctypes.c_void_p() for _ in range(9)]
        d_q, d_k, d_v, d_o, d_lse, d_do, d_dq, d_dk, d_dv = ptrs
        inputs = [Q_c, K_c, V_c, O_c, LSE_c, dO_c]
        grads = [dQ_c, dK_c, dV_c]
        try:
            for dev, host in zip(ptrs[:6], inputs):
                self._upload(dev, host)
            for dev, host in zip(ptrs[6:], grads):
                self._alloc_zeroed(dev, host.nbytes)

            rc, elapsed = self._lib.run_bwd(
                d_q,
                d_k,
                d_v,
                d_o,
                d_lse,
                d_do,
                d_dq,
                d_dk,
                d_dv,
                prob,
                data_type,
                mask_type=mask_type,
                bias_type=bias_type,
                has_dropout=has_dropout,
                has_dbias=has_dbias,
                is_deterministic=is_deterministic,
                is_group_mode=is_group_mode,
                is_store_randval=is_store_randval,
                tile_n0=tile_n0,
            )
            if rc != 0:
                return FmhaResult(success=False, error=f"BWD kernel failed (rc={rc})")

            for dev, host in zip(ptrs[6:], grads):
                self._download(host, dev)
            if data_type == "bf16":
                # Gradients come back as bf16 bit patterns; expose float32.
                grads = [_bf16_to_float32(g) for g in grads]

            # NOTE(review): throughput uses the same prob.num_ops count as the
            # forward path; confirm whether backward should use its own count.
            tflops = prob.num_ops / (elapsed * 1e-3) / 1e12 if elapsed > 0 else 0.0
            return FmhaResult(
                success=True,
                output=tuple(grads),
                time_ms=elapsed,
                tflops=tflops,
            )
        except self._HipError as exc:
            return FmhaResult(success=False, error=str(exc))
        finally:
            self._free_all(ptrs)

    @property
    def kernel_count(self) -> int:
        """Kernel count reported by the dispatcher library."""
        return self._lib.kernel_count()

    @property
    def library_path(self) -> str:
        """Path of the loaded dispatcher library, as a string."""
        return str(self._lib.path)

    def cleanup(self):
        """Forward to the dispatcher library's cleanup call."""
        self._lib.cleanup()
# =============================================================================
# JIT Build Support (mirrors setup_multiple_gemm_dispatchers)
# =============================================================================
@dataclass
class FmhaSetupResult:
    """Outcome of setting up one kernel: only ``success`` is required."""

    # True when setup completed.
    success: bool
    # Kernel configuration this result refers to, if any.
    config: Optional[FmhaKernelConfig] = None
    # Runner created for the kernel, if any.
    runner: Optional[FmhaRunner] = None
    # Path of the associated library; empty when there is none.
    library_path: str = ""
    # Failure description; empty on success.
    error: str = ""
    # Build duration in seconds.
    build_time_s: float = 0.0
def _build_static_lib(root: Path) -> Optional[Path]:
    """Configure and build libck_tile_dispatcher.a under ``root / "build"``.

    Runs ``cmake`` and then ``make ck_tile_dispatcher`` in the build directory.
    If either step exits non-zero, a warning with the first 200 characters of
    its stderr is printed and None is returned.

    Returns:
        Path to ``build/libck_tile_dispatcher.a`` if it exists afterwards,
        otherwise None.
    """
    out_dir = root / "build"
    out_dir.mkdir(parents=True, exist_ok=True)

    compiler = _find_hipcc()
    steps = (
        ("cmake", ["cmake", str(root), f"-DCMAKE_CXX_COMPILER={compiler}"]),
        ("make", ["make", "ck_tile_dispatcher", f"-j{os.cpu_count() or 4}"]),
    )
    for tool, command in steps:
        proc = subprocess.run(
            command, cwd=str(out_dir), capture_output=True, text=True
        )
        if proc.returncode != 0:
            print(
                f"Warning: {tool} failed for dispatcher lib: {proc.stderr[:200]}",
                file=sys.stderr,
            )
            return None

    archive = out_dir / "libck_tile_dispatcher.a"
    if archive.exists():
        return archive
    return None
def _find_static_lib() -> Optional[Path]:
    """Locate libck_tile_dispatcher.a, building it when it is not present.

    Checks ``build/`` and ``build/lib/`` under the dispatcher root; if neither
    holds the archive, prints a notice to stderr and delegates to
    ``_build_static_lib``.
    """
    base = get_dispatcher_root()
    candidates = (
        base / "build/libck_tile_dispatcher.a",
        base / "build/lib/libck_tile_dispatcher.a",
    )
    existing = next((path for path in candidates if path.exists()), None)
    if existing is not None:
        return existing
    # Auto-build if not found
    print("  Building libck_tile_dispatcher.a (first time)...", file=sys.stderr)
    return _build_static_lib(base)
def _find_hipcc() -> str:
    """Return the first existing hipcc among the known install paths.

    Falls back to the bare command name "hipcc" when neither path exists.
    """
    known_locations = ("/opt/rocm/bin/hipcc", "/usr/bin/hipcc")
    return next(
        (location for location in known_locations if os.path.exists(location)),
        "hipcc",
    )
def fmha_compile_flags(arch: str, hipcc: str = "", family: str = "") -> List[str]:
    """Build the base hipcc command line for compiling FMHA kernels.

    Shared by JIT and tile engine. The returned list starts with the compiler
    executable itself; callers append the source file and ``-o <obj>``.

    Source: example/ck_tile/01_fmha/CMakeLists.txt — mirrors CK's own build
    flags to ensure parity. Key defines:
    - CK_TILE_FMHA_FWD_FAST_EXP2: enables fast exp2 on gfx9 (CDNA)
    - CK_TILE_USE_OCP_FP8: uses OCP standard fp8 format
    - CK_GFX950_SUPPORT / CK_USE_GFX950: enables gfx950-specific code paths
    - CK_USE_XDL: enables MFMA (matrix fused multiply-add) instructions
    - CK_TILE_USE_WMMA: 0 for CDNA (uses MFMA instead)
    - CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3: BWD bf16 conversion mode

    Args:
        arch: GPU target passed to ``--offload-arch`` (e.g. "gfx942").
        hipcc: Compiler path; when empty, _find_hipcc() is used.
        family: Kernel family; values starting with "bwd" add the BWD define.

    Returns:
        A new list of command-line tokens.
    """
    compiler = hipcc or _find_hipcc()
    root = get_dispatcher_root()

    flags = [
        compiler,
        "-c",
        "-fPIC",
        "-O3",
        "-DNDEBUG",
        f"--offload-arch={arch}",
        "-std=c++17",
    ]

    include_dirs = (root.parent / "include", root / "include", root.parent)
    flags.extend(f"-I{inc}" for inc in include_dirs)

    flags.extend(
        [
            "-Wno-undefined-func-template",
            "-Wno-float-equal",
            "-fgpu-flush-denormals-to-zero",
            "-fno-offload-uniform-block",
        ]
    )

    # Each LLVM backend option must be preceded by its own "-mllvm".
    llvm_options = (
        "--lsr-drop-solution=1",
        "-enable-post-misched=0",
        "-amdgpu-early-inline-all=true",
        "-amdgpu-function-calls=false",
    )
    for option in llvm_options:
        flags.extend(("-mllvm", option))

    if arch.startswith("gfx9"):
        flags.extend(
            [
                "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
                "-DCK_TILE_USE_OCP_FP8",
                "-DCK_GFX950_SUPPORT",
                "-DCK_USE_GFX950",
                "-DCK_USE_GFX94",
                "-DCK_USE_XDL",
                "-DCK_TILE_USE_WMMA=0",
            ]
        )
    else:
        flags.append("-DCK_TILE_FMHA_FWD_FAST_EXP2=0")

    # API enablement flags (match CMakeLists.txt conditional defines)
    flags.extend(
        [
            "-DCK_TILE_FMHA_FWD_SPLITKV_API=1",
            "-DCK_TILE_FMHA_FWD_APPENDKV_API=1",
            "-DCK_TILE_FMHA_FWD_PAGEDKV_API=1",
        ]
    )

    # BWD-specific flags
    if family.startswith("bwd"):
        flags.append("-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3")
    return flags
def _make_splitkv_combine_config(splitkv_cfg: FmhaKernelConfig) -> FmhaKernelConfig:
    """Derive the fwd_splitkv_combine config that pairs with a fwd_splitkv config.

    Source: fmha_fwd.py splitkv_combine tile — fixed (32, hdim_v, 32, 32) tile.
    The combine_bn1=32 comes from specs.py load_arch_specs() splitkv_combine dict.

    The combine kernel merges partial results from the split stage into the
    final output. Must be in the same .so as the split kernel for the
    2-stage splitkv pipeline.

    The input config is not modified: a shallow copy is made and the fields
    listed below are overwritten on the copy. Every field not listed keeps
    the value from ``splitkv_cfg``.
    """
    import copy

    out_hdim = splitkv_cfg.hdim_v
    overrides = {
        "family": "fwd_splitkv_combine",
        "pipeline": "splitkv_combine",
        # Both head dims are set to the split config's hdim_v.
        "hdim_q": out_hdim,
        "hdim_v": out_hdim,
        "tile_m0": 32,
        "tile_n0": out_hdim,
        "tile_k0": 32,
        "tile_n1": 32,
        "tile_k1": 0,
        "tile_k0max": 0,
        # Sequence padding is only enabled for group mode.
        "pad_s": 1 if splitkv_cfg.mode == "group" else 0,
        "pad_sk": 1,
        "pad_d": 1,
        "pad_dv": 1,
        "lse": True,
        # The dispatcher's supports() check matches the combine kernel's
        # signature against the problem traits. Fields not overridden here
        # are inherited from the split config so those parts of the
        # signature match; the four below are forced to their disabled values.
        "dropout": False,
        "skip_min_seqlen_q": False,
        "qscale": "no",
        # NOTE(review): _make_bwd_dot_do_o_config spells the disabled rope
        # value "no" while this uses "none" — confirm which spelling
        # FmhaKernelConfig / codegen expects.
        "rope": "none",
    }

    combine_cfg = copy.copy(splitkv_cfg)
    for field_name, value in overrides.items():
        setattr(combine_cfg, field_name, value)
    return combine_cfg
def _make_bwd_dot_do_o_config(dq_cfg: FmhaKernelConfig) -> FmhaKernelConfig:
    """Derive the bwd_dot_do_o config that pairs with a bwd_dq_dk_dv config.

    Source: fmha_bwd.py FmhaBwdDotDoOTileSize — fixed tile (64, max(hv,128), 32).
    Warp tile (32,32,16) with 4 waves in M = standard fp16/bf16 MFMA config.

    The dot_do_o kernel computes d = rowsum(O * dO) and must be in the same
    .so as the dq_dk_dv kernel for the 2-stage BWD pipeline.

    The input config is not modified: a shallow copy is made and the fields
    listed below are overwritten on the copy. Anything not listed (including
    the dropout setting, so that store_randval matches the problem's traits)
    is inherited from ``dq_cfg``.
    """
    import copy

    head_q, head_v = dq_cfg.hdim_q, dq_cfg.hdim_v
    overrides = {
        "family": "bwd_dot_do_o",
        "pipeline": "qr",
        "tile_m0": 64,
        "tile_n0": max(head_v, 128),
        "tile_k0": 32,
        "tile_n1": max(head_v, 128),
        "tile_k1": 32,
        "tile_k0max": max(head_q, 128),
        # 4 waves along M, identical for both stages.
        "wave_m0": 4,
        "wave_n0": 1,
        "wave_k0": 1,
        "wave_m1": 4,
        "wave_n1": 1,
        "wave_k1": 1,
        # 32x32x16 warp tile, identical for both stages.
        "warp_m0": 32,
        "warp_n0": 32,
        "warp_k0": 16,
        "warp_m1": 32,
        "warp_n1": 32,
        "warp_k1": 16,
        "use_trload": False,
        # dot_do_o uses all-padded for maximum compatibility
        "pad_s": 1,
        "pad_sk": 1,
        "pad_d": 1,
        "pad_dv": 1,
        # BWD traits don't have logits/sink/skip/lse/paged_kv -- from_invocation
        # defaults them to false/0. The dot_do_o signature must match these
        # defaults.
        "logits": False,
        "sink": False,
        "skip_min_seqlen_q": False,
        "lse": False,
        "paged_kv": False,
        "qscale": "no",
        "rope": "no",
    }

    dot_cfg = copy.copy(dq_cfg)
    for field_name, value in overrides.items():
        setattr(dot_cfg, field_name, value)
    return dot_cfg
def setup_fmha_dispatcher(
    config: FmhaKernelConfig,
    output_dir: Optional[Path] = None,
    verbose: bool = False,
) -> FmhaSetupResult:
    """JIT-compile a single FMHA kernel and return a runner.

    Cached: if the .so already exists it is loaded directly and no build
    prerequisites are touched. If the cached library fails to load it is
    treated as stale and rebuilt.

    Fresh build: codegen → parallel compile (kernel + ctypes) → link → load.
    Families that need a companion kernel in the same .so get it generated
    together with the requested kernel:
      - bwd_dq_dk_dv  → plus the matching bwd_dot_do_o kernel
      - fwd_splitkv   → plus the matching fwd_splitkv_combine kernel

    Args:
        config: Kernel configuration to build.
        output_dir: Directory for generated sources, objects and the .so.
            Defaults to ``<root>/build/examples/fmha_jit_<config.name>``.
        verbose: When True, report a stale cached library on stderr.

    Returns:
        FmhaSetupResult with ``success``/``runner``/``library_path`` on
        success, or ``success=False`` and ``error`` describing the failed step.
    """
    import time

    t0 = time.perf_counter()

    def _fail(message: str) -> FmhaSetupResult:
        return FmhaSetupResult(success=False, config=config, error=message)

    root = get_dispatcher_root()
    codegen_dir = root / "codegen"
    ctypes_src = root / "bindings" / "ctypes" / "fmha_ctypes_lib.cpp"
    hipcc = _find_hipcc()

    if output_dir is None:
        output_dir = root / "build" / "examples" / f"fmha_jit_{config.name}"
    output_dir.mkdir(parents=True, exist_ok=True)

    lib_name = f"libdispatcher_fmha_{config.name}.so"
    lib_path = output_dir / lib_name

    # Cache hit: .so already exists, just load
    if lib_path.exists():
        try:
            runner = FmhaRunner.from_library(str(lib_path), config.gfx_arch)
        except Exception as exc:
            # Stale / incompatible .so: fall through and rebuild.
            if verbose:
                print(
                    f"  Cached {lib_name} failed to load ({exc}); rebuilding",
                    file=sys.stderr,
                )
        else:
            return FmhaSetupResult(
                success=True,
                config=config,
                runner=runner,
                library_path=str(lib_path),
                build_time_s=time.perf_counter() - t0,
            )

    # Only look for (and possibly build) the static library once we know a
    # real build is required; a cache hit must not trigger a cmake build.
    static_lib = _find_static_lib()
    if not static_lib:
        return _fail("libck_tile_dispatcher.a not found")
    if not ctypes_src.exists():
        return _fail("fmha_ctypes_lib.cpp not found")

    # Step 1: Codegen
    # Multi-stage families need their companion kernel in the same .so;
    # keep the pairing and ordering identical to
    # setup_multiple_fmha_dispatchers().
    if config.family == "bwd_dq_dk_dv":
        codegen_cfgs = [_make_bwd_dot_do_o_config(config), config]
    elif config.family == "fwd_splitkv":
        codegen_cfgs = [config, _make_splitkv_combine_config(config)]
    else:
        codegen_cfgs = []

    if codegen_cfgs:
        config_json_str = json.dumps(
            [json.loads(c.to_codegen_json()) for c in codegen_cfgs]
        )
    else:
        config_json_str = config.to_codegen_json()

    gen_cmd = [
        sys.executable,
        str(codegen_dir / "fmha" / "generate_fallback.py"),
        "--output-dir",
        str(output_dir),
        "--gpu-target",
        config.gfx_arch,
        "--config-json",
        config_json_str,
    ]
    r = subprocess.run(gen_cmd, capture_output=True, text=True, cwd=str(codegen_dir))
    if r.returncode != 0:
        return _fail(f"Codegen failed: {r.stderr[:500]}")

    dispatch_header = output_dir / "fmha_python_dispatch.hpp"
    if not dispatch_header.exists():
        return _fail("Dispatch header not generated")

    # Step 2: Compile kernel .cpp AND ctypes in parallel
    kernel_cpps = list(output_dir.glob("fmha_*.cpp"))
    base_flags = fmha_compile_flags(config.gfx_arch, hipcc, family=config.family)

    compile_jobs = []
    for cpp in kernel_cpps:
        obj = cpp.with_suffix(".o")
        compile_jobs.append((base_flags + [str(cpp), "-o", str(obj)], obj, "kernel"))

    # The ctypes binding is compiled with the generated dispatch header
    # force-included so it sees the kernels produced by codegen.
    ctypes_obj = output_dir / "fmha_ctypes_lib.o"
    ctypes_cmd = base_flags + [
        f"-I{output_dir}",
        f"-I{output_dir / 'dispatcher_wrappers'}",
        f"-include{dispatch_header}",
        f'-DGFX_ARCH="{config.gfx_arch}"',
        str(ctypes_src),
        "-o",
        str(ctypes_obj),
    ]
    compile_jobs.append((ctypes_cmd, ctypes_obj, "ctypes"))

    def _run_compile(job):
        cmd, obj, label = job
        # Reuse an object left over from a previous (partial) build.
        if obj.exists():
            return (True, obj, label, "")
        proc = subprocess.run(cmd, capture_output=True, text=True)
        return (proc.returncode == 0, obj, label, proc.stderr[:500])

    with ThreadPoolExecutor(max_workers=len(compile_jobs)) as pool:
        results = list(pool.map(_run_compile, compile_jobs))

    kernel_objs = []
    for ok, obj, label, err in results:
        if not ok:
            return _fail(f"{label} compile failed: {err}")
        if label == "kernel":
            kernel_objs.append(str(obj))

    # Step 3: Link
    link_cmd = [
        hipcc,
        "-shared",
        "-fPIC",
        str(ctypes_obj),
        *kernel_objs,
        str(static_lib),
        "-o",
        str(lib_path),
    ]
    r = subprocess.run(link_cmd, capture_output=True, text=True)
    if r.returncode != 0:
        return _fail(f"Link failed: {r.stderr[:500]}")

    # Step 4: Load
    try:
        runner = FmhaRunner.from_library(str(lib_path), config.gfx_arch)
    except Exception as e:
        return _fail(f"Load failed: {e}")

    elapsed = time.perf_counter() - t0
    return FmhaSetupResult(
        success=True,
        config=config,
        runner=runner,
        library_path=str(lib_path),
        build_time_s=elapsed,
    )
def _run_compile_job(job):
"""Module-level compile worker -- no threads, uses file-based stderr."""
cmd, obj_str, name, label = job
if os.path.exists(obj_str):
return (name, True, "")
err_path = obj_str + ".err"
with open(err_path, "w") as ef:
rc = subprocess.call(cmd, stdout=subprocess.DEVNULL, stderr=ef)
if rc != 0:
try:
err = open(err_path).read()[:200]
except Exception:
err = f"rc={rc}"
return (name, False, err)
try:
os.unlink(err_path)
except OSError:
pass
return (name, True, "")
def setup_multiple_fmha_dispatchers(
    configs: List[FmhaKernelConfig],
    output_dir: Optional[Path] = None,
    verbose: bool = False,
    max_workers: Optional[int] = None,
    executor=None,
    progress_callback=None,
) -> List[FmhaSetupResult]:
    """3-stage batched JIT: codegen (sequential) -> compile (pooled) -> link (sequential).

    Faster than calling setup_fmha_dispatcher() per-kernel because all hipcc
    compile jobs (kernel + ctypes from ALL kernels) share one worker pool.
    Libraries are only built, not loaded: returned results carry
    ``library_path`` but no ``runner``.

    Each config gets its own directory ``<output_dir>/fmha_jit_<name>``.
    Work already present there is reused: an existing .so is reported as
    success, an existing dispatch header skips codegen, existing object files
    skip their compile, and a recorded codegen failure is reported without
    re-running codegen.

    Args:
        configs: Kernel configs to build. An empty list returns ``[]``.
        output_dir: Parent directory for the per-kernel build directories.
            Defaults to ``<root>/build/examples``.
        verbose: Accepted for interface parity; not used by this function.
        max_workers: Size of the internally created process pool. Defaults
            to ``min(number of compile jobs, cpu count)``.
        executor: Optional externally owned executor used for the compile
            stage; it is not shut down here.
        progress_callback: Optional ``callback(stage, done, total)`` invoked
            with stage "codegen" or "compile".

    Returns:
        One FmhaSetupResult per input config, in input order.
    """
    if not configs:
        return []

    root = get_dispatcher_root()
    codegen_dir = root / "codegen"
    ctypes_src = root / "bindings" / "ctypes" / "fmha_ctypes_lib.cpp"
    hipcc = _find_hipcc()

    if output_dir is None:
        output_dir = root / "build" / "examples"

    results: dict[str, FmhaSetupResult] = {}

    # --- Stage 1: Codegen (sequential, skip cached) ---
    def _codegen(cfg):
        out = output_dir / f"fmha_jit_{cfg.name}"
        lib_path = out / f"libdispatcher_fmha_{cfg.name}.so"
        header = out / "fmha_python_dispatch.hpp"
        err_file = out / "_codegen_err.txt"

        # Fast path: .so exists, register result and skip
        if lib_path.exists():
            results[cfg.name] = FmhaSetupResult(
                success=True, config=cfg, library_path=str(lib_path)
            )
            return (cfg.name, cfg, out, True)

        # Fast path: previous codegen already failed (no .hpp generated)
        if out.exists() and not header.exists() and err_file.exists():
            results[cfg.name] = FmhaSetupResult(
                success=False, config=cfg, error="Codegen failed (cached)"
            )
            return (cfg.name, cfg, out, False)

        out.mkdir(parents=True, exist_ok=True)

        # Check if codegen was already done (has .hpp but no .so yet)
        if header.exists():
            return (cfg.name, cfg, out, True)

        # Multi-stage families need their companion kernel in the same .so.
        if cfg.family == "bwd_dq_dk_dv":
            dot = _make_bwd_dot_do_o_config(cfg)
            config_json_str = json.dumps(
                [
                    json.loads(dot.to_codegen_json()),
                    json.loads(cfg.to_codegen_json()),
                ]
            )
        elif cfg.family == "fwd_splitkv":
            comb = _make_splitkv_combine_config(cfg)
            config_json_str = json.dumps(
                [
                    json.loads(cfg.to_codegen_json()),
                    json.loads(comb.to_codegen_json()),
                ]
            )
        else:
            config_json_str = cfg.to_codegen_json()

        # stderr goes to a file that doubles as the "codegen failed" marker
        # consulted by the fast path above on later runs.
        with open(err_file, "w") as ef:
            rc = subprocess.call(
                [
                    sys.executable,
                    str(codegen_dir / "fmha" / "generate_fallback.py"),
                    "--output-dir",
                    str(out),
                    "--gpu-target",
                    cfg.gfx_arch,
                    "--config-json",
                    config_json_str,
                ],
                stdout=subprocess.DEVNULL,
                stderr=ef,
                cwd=str(codegen_dir),
            )
        ok = rc == 0 and header.exists()
        if not ok:
            err_msg = err_file.read_text()[:200] if err_file.exists() else "unknown"
            results[cfg.name] = FmhaSetupResult(
                success=False, config=cfg, error=f"Codegen failed: {err_msg}"
            )
        return (cfg.name, cfg, out, ok)

    codegen_results = []
    for i, cfg in enumerate(configs):
        codegen_results.append(_codegen(cfg))
        if progress_callback:
            progress_callback("codegen", i + 1, len(configs))

    # Configs that still need compile + link (codegen ok, no final result yet).
    pending = [
        (name, cfg, out)
        for name, cfg, out, ok in codegen_results
        if ok and name not in results
    ]

    # Only look for (and possibly build) the static library when something
    # actually has to be linked; a fully cached run must not trigger a build.
    static_lib = None
    if pending:
        static_lib = _find_static_lib()
        if not static_lib:
            # Fail early with a clear message instead of passing the literal
            # string "None" to the linker after compiling everything.
            for name, cfg, _ in pending:
                results[name] = FmhaSetupResult(
                    success=False,
                    config=cfg,
                    error="libck_tile_dispatcher.a not found",
                )
            pending = []

    # --- Stage 2: Collect ALL compile jobs, run in one pool ---
    # Use bwd family flag to get the superset of all flags (includes
    # BWD-specific defines). Flags are computed per target arch so a batch
    # with mixed gfx_arch values compiles each kernel for the arch it was
    # generated for.
    flags_by_arch: dict[str, List[str]] = {}
    compile_jobs = []  # (cmd, obj_path, kernel_name, label)
    config_dirs: dict[str, tuple[FmhaKernelConfig, Path]] = {}

    for name, cfg, out in pending:
        cfg_arch = cfg.gfx_arch
        if cfg_arch not in flags_by_arch:
            flags_by_arch[cfg_arch] = fmha_compile_flags(
                cfg_arch, hipcc, family="bwd"
            )
        base_flags = flags_by_arch[cfg_arch]

        config_dirs[name] = (cfg, out)

        for cpp in out.glob("fmha_*.cpp"):
            obj = cpp.with_suffix(".o")
            if not obj.exists():
                compile_jobs.append(
                    (base_flags + [str(cpp), "-o", str(obj)], str(obj), name, "kernel")
                )

        ctypes_obj = out / "fmha_ctypes_lib.o"
        if not ctypes_obj.exists():
            dispatch = out / "fmha_python_dispatch.hpp"
            compile_jobs.append(
                (
                    base_flags
                    + [
                        f"-I{out}",
                        f"-I{out / 'dispatcher_wrappers'}",
                        f"-include{dispatch}",
                        f'-DGFX_ARCH="{cfg_arch}"',
                        str(ctypes_src),
                        "-o",
                        str(ctypes_obj),
                    ],
                    str(ctypes_obj),
                    name,
                    "ctypes",
                )
            )

    failed_names: set = set()
    if compile_jobs:
        _own_pool = None
        _pool = executor
        if _pool is None:
            workers = max_workers or min(len(compile_jobs), os.cpu_count() or 4)
            _own_pool = ProcessPoolExecutor(max_workers=workers)
            _pool = _own_pool

        try:
            done_count = 0
            total_jobs = len(compile_jobs)
            for name, ok, err in _pool.map(_run_compile_job, compile_jobs):
                done_count += 1
                if progress_callback:
                    progress_callback("compile", done_count, total_jobs)
                if not ok:
                    failed_names.add(name)
                    # Keep the first failure recorded for a kernel.
                    if name not in results:
                        cfg, _ = config_dirs[name]
                        results[name] = FmhaSetupResult(
                            success=False, config=cfg, error=f"Compile: {err}"
                        )
        finally:
            # Only shut down a pool created here; a caller-provided
            # executor stays under the caller's control.
            if _own_pool is not None:
                _own_pool.shutdown(wait=True)

    # --- Stage 3: Link (no GPU access -- runner loading deferred to caller) ---
    def _link(item):
        name, (cfg, out) = item
        if name in failed_names or name in results:
            return
        objs = list(out.glob("*.o"))
        lib_path = out / f"libdispatcher_fmha_{name}.so"
        if not lib_path.exists():
            r = subprocess.run(
                [
                    hipcc,
                    "-shared",
                    "-fPIC",
                    *[str(o) for o in objs],
                    str(static_lib),
                    "-o",
                    str(lib_path),
                ],
                capture_output=True,
                text=True,
            )
            if r.returncode != 0:
                results[name] = FmhaSetupResult(
                    success=False, config=cfg, error=f"Link: {r.stderr[:200]}"
                )
                return
        results[name] = FmhaSetupResult(
            success=True, config=cfg, library_path=str(lib_path)
        )

    for item in config_dirs.items():
        _link(item)

    # Return in original order
    return [
        results.get(c.name, FmhaSetupResult(success=False, config=c, error="skipped"))
        for c in configs
    ]
# =============================================================================
# Registry (mirrors ctypes_utils.Registry)
# =============================================================================
class FmhaRegistry:
    """Kernel registry with parallel JIT build support.

    Collects FmhaKernelConfig objects in registration order and hands the
    whole collection to setup_multiple_fmha_dispatchers() on build().
    """

    def __init__(self, name: str = "fmha"):
        self._kernels: List[FmhaKernelConfig] = []
        self._name = name

    def register_kernel(self, config: FmhaKernelConfig):
        """Add ``config`` to the end of the registry."""
        registered = self._kernels
        registered.append(config)

    def __len__(self):
        """Number of registered kernel configs."""
        registered = self._kernels
        return len(registered)

    def build(
        self,
        verbose: bool = False,
        max_workers: Optional[int] = None,
    ) -> List[FmhaSetupResult]:
        """Build every registered kernel; returns one result per kernel."""
        build_options = {"verbose": verbose, "max_workers": max_workers}
        return setup_multiple_fmha_dispatchers(self._kernels, **build_options)
# =============================================================================
# Validator (mirrors ctypes_utils.Validator)
# =============================================================================
class FmhaValidator:
    """Validates FMHA GPU output against a reference.

    Usage:
        validator = FmhaValidator(rtol=1e-2, atol=1e-2)
        ok, max_abs, max_rel = validator.check(gpu_output, cpu_reference)
    """

    def __init__(self, rtol: float = 1e-2, atol: float = 1e-2):
        self.rtol = rtol
        self.atol = atol

    def check(
        self, output: np.ndarray, reference: np.ndarray
    ) -> Tuple[bool, float, float]:
        """Check output against reference.

        Both inputs are converted to float32 before comparison. Pass/fail is
        decided by ``np.allclose`` with this validator's ``atol``/``rtol``;
        the two error figures are reported for diagnostics only. Shapes
        follow NumPy broadcasting rules.

        Zero-size inputs have nothing to compare and are reported as a
        match with zero error.

        Returns:
            (is_valid, max_abs_error, max_rel_error)
        """
        out_f32 = np.asarray(output).astype(np.float32)
        ref_f32 = np.asarray(reference).astype(np.float32)
        diff = np.abs(out_f32 - ref_f32)
        if diff.size == 0:
            # .max() has no identity on empty arrays and would raise;
            # np.allclose already treats empty input as matching.
            return True, 0.0, 0.0
        max_abs = float(diff.max())
        # Small epsilon keeps the relative error finite where reference == 0.
        max_rel = float((diff / (np.abs(ref_f32) + 1e-6)).max())
        ok = bool(np.allclose(out_f32, ref_f32, atol=self.atol, rtol=self.rtol))
        return ok, max_abs, max_rel
# =============================================================================
# KernelSpec + spec_to_config (mirrors ctypes_utils.KernelSpec)
# =============================================================================
@dataclass
class FmhaKernelSpec:
    """High-level kernel specification for easy declaration.

    Mirrors GEMM's KernelSpec: specify name + key dimensions, get a
    full FmhaKernelConfig via spec_to_config().
    """

    # Identifier chosen by the caller. spec_to_config() does not forward it
    # to the generated FmhaKernelConfig.
    name: str
    # Head dimension; spec_to_config() uses it for hdim_q, hdim_v, tile_n1
    # and tile_k0max.
    hdim: int = 128
    # Pipeline name passed through to FmhaKernelConfig.pipeline.
    pipeline: str = "qr_async"
    # Stage 0 tile (Q*K^T)
    tile_m0: int = 128
    tile_n0: int = 128
    # Also reused as tile_k1 by spec_to_config().
    tile_k0: int = 32
def spec_to_config(
    spec: FmhaKernelSpec, dtype: str = "fp16", arch: str = "gfx950"
) -> FmhaKernelConfig:
    """Expand a high-level FmhaKernelSpec into a full FmhaKernelConfig.

    The spec's single ``hdim`` is used for both Q and V head dimensions as
    well as for tile_n1 and tile_k0max, and its ``tile_k0`` is reused as
    tile_k1. Every FmhaKernelConfig field not set here keeps that class's
    default.

    Args:
        spec: High-level kernel description.
        dtype: Value for FmhaKernelConfig.data_type.
        arch: Value for FmhaKernelConfig.gfx_arch.
    """
    head_dim = spec.hdim
    stage0_k = spec.tile_k0
    fields = {
        "data_type": dtype,
        "hdim_q": head_dim,
        "hdim_v": head_dim,
        "pipeline": spec.pipeline,
        "tile_m0": spec.tile_m0,
        "tile_n0": spec.tile_n0,
        "tile_k0": stage0_k,
        "tile_n1": head_dim,
        "tile_k1": stage0_k,
        "tile_k0max": head_dim,
        "gfx_arch": arch,
    }
    return FmhaKernelConfig(**fields)
# =============================================================================
# Split-K heuristic (from fmhaarch.md Section 9.5)
# =============================================================================