mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-25 07:14:37 +00:00
[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher (#5260) ## Motivation The CK Tile dispatcher currently supports GEMM and Grouped Convolution but has no support for Fused Multi-Head Attention (FMHA). The example/ck_tile/01_fmha folder contains a comprehensive FMHA implementation with forward, backward, split-KV, paged-KV, append-KV, and batch-prefill kernels across multiple GPU architectures — but there is no unified dispatch layer for it. This PR ports the FMHA stack into the dispatcher, following the same architectural patterns established by GEMM and Grouped Convolution, enabling runtime kernel selection, JIT compilation from Python, and a declarative C++ example flow. Autotuning heuristics to follow. ## Technical Details This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring GEMM's layered architecture. Seven new C++ runtime headers provide type definitions (coexisting with upstream headers via __has_include, requiring zero modifications to example/ck_tile/01_fmha/), a problem builder with 18+ setters, Signature + Algorithm kernel key matching, a virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard support and named tile/wave/warp setters, an arch-aware registry with JSON export, and a dispatcher with seqtune-aware selection, configurable timing, and multi-stage execution plans for split-KV (two-stage) and backward (three-stage). The codegen pipeline is driven by an fmha_arch_specs.json file capturing per-arch tile tables and pipeline constraints for five architectures (gfx90a/942/950/1100/1201), migrated from hardcoded logic in 01_fmha/codegen/, with supporting modules for C++ symbol mappings, validation rules, and named receipt profiles (ck_default, flash, pytorch, aiter, fp32, fp8).
Python integration (fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel multi-kernel builds, HIP memory management via ctypes, tolerance-based validation, and a NumPy CPU reference with GQA support. Twenty-seven C++ and thirty-two Python examples cover the full feature surface — forward, split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill, fp8, logits soft cap, sink tokens, and parameter sweeps — all JIT-compiled on the fly. ## Test Plan Seven test files cover the runtime types, codegen, and end-to-end correctness. C++ unit tests validate the problem builder, dispatcher planning (single-stage for forward/paged-KV/append-KV; multi-stage for split-KV and backward), registry operations, and the kernel-set declaration macro. Python unit tests verify codegen emission, profile filtering, and 15 validation rules for masks, hdim constraints, and pipeline requirements. GPU execution validation in 01_basic_fmha --validate reports zero errors across 65,536 elements with max absolute error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py) runs 14 configurations through both the upstream tile_example_fmha_fwd and the dispatcher, comparing exit codes to confirm behavioral parity — all 14 match. ## Test Result The C++ smoke test builds and passes all 9 compiled examples, and a Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching up to 375 TFLOPS at seqlen 2048. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
1843 lines
60 KiB
Python
1843 lines
60 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
"""
|
|
FMHA Dispatcher Python Utilities
|
|
|
|
Provides Python wrappers for FMHA dispatcher kernels via ctypes.
|
|
Mirrors ctypes_utils.py (GEMM) and grouped_conv_utils.py (Conv).
|
|
|
|
Usage:
|
|
from fmha_utils import FmhaDispatcherLib, FmhaRunner, FmhaProblem, cpu_attention_fwd
|
|
|
|
runner = FmhaRunner.from_prebuilt()
|
|
result = runner.run(Q, K, V, problem)
|
|
"""
|
|
|
|
import ctypes
|
|
import json
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import List, Optional, Tuple
|
|
|
|
import numpy as np
|
|
|
|
|
|
# =============================================================================
|
|
# Utility helpers
|
|
# =============================================================================
|
|
|
|
|
|
try:
    from dispatcher_common import detect_gpu_arch, get_dispatcher_root
except ImportError:
    # dispatcher_common is not importable (standalone use outside the
    # dispatcher tree), so define local stand-ins with the same names.

    def get_dispatcher_root() -> Path:
        """Return the directory two levels above this file."""
        this_file = Path(__file__)
        return this_file.parent.parent

    def detect_gpu_arch(fallback: str = "gfx950") -> str:
        """Best-effort GPU architecture lookup through ``rocminfo``.

        Returns the last whitespace-separated token of the first ``rocminfo``
        output line that contains both ``"Name:"`` and ``"gfx"``. Any failure
        to run ``rocminfo`` (missing binary, non-zero exit, ...) and the
        absence of such a line both yield *fallback*.
        """
        try:
            report = subprocess.check_output(
                ["rocminfo"], text=True, stderr=subprocess.DEVNULL
            )
        except Exception:
            # Deliberately best-effort: detection must never break imports.
            return fallback
        matches = (
            entry.split()[-1].strip()
            for entry in report.splitlines()
            if "Name:" in entry and "gfx" in entry
        )
        return next(matches, fallback)
|
|
|
|
|
|
# =============================================================================
|
|
# Data types
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class FmhaResult:
    """Outcome of one dispatcher invocation.

    When ``success`` is False, ``error`` carries a human-readable reason and
    the remaining fields keep their defaults. On success, ``output`` holds the
    host-side result produced by the calling method, ``time_ms`` the time
    reported by the library and ``tflops`` the derived throughput (0.0 when it
    was not computed).
    """

    success: bool  # True when the run completed without a reported failure
    output: Optional[np.ndarray] = None  # host-side result; None on failure
    time_ms: float = 0.0  # elapsed time reported by the dispatcher, in ms
    tflops: float = 0.0  # throughput derived from time_ms; 0.0 if unavailable
    error: str = ""  # failure description; empty on success
|
|
|
|
|
|
@dataclass
class FmhaProblem:
    """Dimensions of a single attention problem.

    Tensor shapes follow the [batch, heads, seqlen, hdim] order returned by
    the ``*_shape`` helpers. ``nhead_k`` may differ from ``nhead_q``
    (grouped-query attention), in which case K and V use ``nhead_k`` heads.
    """

    batch: int = 2
    nhead_q: int = 8
    nhead_k: int = 8
    seqlen_q: int = 128
    seqlen_k: int = 128
    hdim_q: int = 128
    hdim_v: int = 128

    @property
    def scale(self) -> float:
        """Default softmax scale: 1 / sqrt(hdim_q)."""
        root = self.hdim_q**0.5
        return 1.0 / root

    @property
    def num_ops(self) -> int:
        """Operation count of Q*K^T plus P*V (2 ops per multiply-add), no masking."""
        b, h = self.batch, self.nhead_q
        return 2 * b * h * self.seqlen_q * self.seqlen_k * (self.hdim_q + self.hdim_v)

    def q_shape(self):
        """Shape of Q: (batch, nhead_q, seqlen_q, hdim_q)."""
        return self.batch, self.nhead_q, self.seqlen_q, self.hdim_q

    def k_shape(self):
        """Shape of K: (batch, nhead_k, seqlen_k, hdim_q)."""
        return self.batch, self.nhead_k, self.seqlen_k, self.hdim_q

    def v_shape(self):
        """Shape of V: (batch, nhead_k, seqlen_k, hdim_v)."""
        return self.batch, self.nhead_k, self.seqlen_k, self.hdim_v

    def o_shape(self):
        """Shape of the output O: (batch, nhead_q, seqlen_q, hdim_v)."""
        return self.batch, self.nhead_q, self.seqlen_q, self.hdim_v
|
|
|
|
|
|
@dataclass
class FmhaKernelConfig:
    """Complete kernel configuration for FMHA.

    All tile/wave/warp dimensions are explicitly named to match the
    GEMM pattern (tile_m, tile_n, tile_k) but extended for FMHA's
    two-stage computation (Q*K^T stage 0, Attn*V stage 1).

    Two views are derived from the fields: ``name`` builds an identifier from
    a subset of them, and ``to_codegen_json`` serialises the signature and
    algorithm description as JSON.
    """

    # -- Signature: what operation --
    family: str = "fwd"
    data_type: str = "fp16"
    mode: str = "batch"
    vlayout: str = "r"
    hdim_q: int = 128
    hdim_v: int = 128
    gfx_arch: str = "gfx950"

    # -- Algorithm: tile shape --
    # Stage 0 (Q * K^T): seqlen_q x seqlen_k x hdim_q
    tile_m0: int = 128  # seqlen_q tile
    tile_n0: int = 128  # seqlen_k tile
    tile_k0: int = 32  # hdim_q tile
    # Stage 1 (Attn * V): seqlen_q x hdim_v x seqlen_k
    tile_n1: int = 128  # hdim_v tile
    tile_k1: int = 32  # seqlen_k tile
    tile_k0max: int = 128  # max k0 (alignment)
    # BWD extra stages (9-element tile)
    tile_bwd6: int = 0
    tile_bwd7: int = 0
    tile_bwd8: int = 0

    # -- Algorithm: wave config (warps per block) --
    wave_m0: int = 4
    wave_n0: int = 1
    wave_k0: int = 1
    wave_m1: int = 4
    wave_n1: int = 1
    wave_k1: int = 1
    wave_m2: int = 1
    wave_n2: int = 1
    wave_k2: int = 1

    # -- Algorithm: warp tile (elements per warp) --
    warp_m0: int = 32
    warp_n0: int = 32
    warp_k0: int = 16
    warp_m1: int = 32
    warp_n1: int = 32
    warp_k1: int = 16
    warp_m2: int = 16
    warp_n2: int = 16
    warp_k2: int = 16

    # -- Algorithm: padding --
    # Values: 0=no pad, 1=pad, 8=pad with 8-byte alignment (BWD-specific)
    pad_s: int = 1
    pad_sk: int = 1
    pad_d: int = 1
    pad_dv: int = 1

    # -- Algorithm: pipeline --
    pipeline: str = "qr_async"
    block_per_cu: int = -1
    num_wave_groups: int = 1

    # -- Signature: features --
    mask: str = "no"
    bias: str = "no"
    lse: bool = False
    dropout: bool = False
    qscale: str = "no"
    rope: str = "none"
    logits: bool = False
    paged_kv: bool = False
    sink: bool = False
    skip_min_seqlen_q: bool = False
    page_size: int = 1
    kv_memory_layout: str = "vectorized"
    kv_lookup_table: str = "sglang"
    deterministic: bool = False
    dbias: bool = False
    dropout_variant: str = ""  # BWD: "no"/"dropout_wg16"/"dropout_wg16_storerandval"
    tile_tag: str = ""  # extra tile variant discriminator (e.g. "trload", "small")
    use_trload: bool = False  # BWD dq_dk_dv: use trload pipeline path

    @property
    def tile(self) -> Tuple[int, ...]:
        """Block tile as a tuple.

        Six elements (m0, n0, k0, n1, k1, k0max); the three BWD extras are
        appended only for ``bwd_dq_dk_dv`` when ``tile_bwd6 > 0``.
        """
        base = (
            self.tile_m0,
            self.tile_n0,
            self.tile_k0,
            self.tile_n1,
            self.tile_k1,
            self.tile_k0max,
        )
        if self.family == "bwd_dq_dk_dv" and self.tile_bwd6 > 0:
            return base + (self.tile_bwd6, self.tile_bwd7, self.tile_bwd8)
        return base

    @property
    def wave(self) -> Tuple[int, ...]:
        """Warps-per-block for the three stages, flattened to 9 elements."""
        return (
            self.wave_m0,
            self.wave_n0,
            self.wave_k0,
            self.wave_m1,
            self.wave_n1,
            self.wave_k1,
            self.wave_m2,
            self.wave_n2,
            self.wave_k2,
        )

    @property
    def warp(self) -> Tuple[int, ...]:
        """Per-warp tile sizes for the three stages, flattened to 9 elements."""
        return (
            self.warp_m0,
            self.warp_n0,
            self.warp_k0,
            self.warp_m1,
            self.warp_n1,
            self.warp_k1,
            self.warp_m2,
            self.warp_n2,
            self.warp_k2,
        )

    @property
    def padding(self) -> Tuple[int, ...]:
        """Padding flags (pad_s, pad_sk, pad_d, pad_dv).

        These are ints, not bools: besides 0/1 the value 8 is used (see the
        padding field comment), so the annotation reflects the real values.
        """
        return (self.pad_s, self.pad_sk, self.pad_d, self.pad_dv)

    @property
    def name(self) -> str:
        """Identifier built from the config, fields joined with ``_``.

        Always present: family/dtype, mode, head dims, pipeline, 6-element
        tile (+ ``.tile_tag``), stage-0 warp tile, padding, mask, bias and
        block_per_cu. Feature flags are appended only when they differ from
        their defaults, in a fixed order. Wave config, vlayout, gfx_arch,
        the BWD tile extras and stage-1/2 warp tiles are NOT part of the name.
        """
        hdim = (
            f"h{self.hdim_q}x{self.hdim_v}"
            if self.hdim_q != self.hdim_v
            else f"h{self.hdim_q}"
        )
        tile = (
            f"t{self.tile_m0}x{self.tile_n0}x{self.tile_k0}"
            f"x{self.tile_n1}x{self.tile_k1}x{self.tile_k0max}"
        )
        if self.tile_tag:
            tile += f".{self.tile_tag}"
        parts = [
            f"fmha_{self.family}_{self.data_type}",
            self.mode,
            hdim,
            self.pipeline,
            tile,
            # Always include warp class for uniform naming
            f"w{self.warp_m0}x{self.warp_n0}x{self.warp_k0}",
            f"pad{self.pad_s}{self.pad_sk}{self.pad_d}{self.pad_dv}",
            f"mask={self.mask}",
            f"bias={self.bias}",
        ]
        # (enabled, token) in the exact order the tokens must appear.
        optional = (
            (self.lse, "lse=1"),
            (self.dropout, "drop=1"),
            (self.logits, "logits=1"),
            (self.sink, "sink=1"),
            (self.skip_min_seqlen_q, "skip=1"),
            (self.qscale != "no", f"qs={self.qscale}"),
            (self.paged_kv, "pkv=1"),
            (self.rope != "none", f"rope={self.rope}"),
            (self.page_size != 1, f"ps={self.page_size}"),
            (self.kv_memory_layout != "vectorized", f"kvl={self.kv_memory_layout}"),
            (self.kv_lookup_table != "sglang", f"kvt={self.kv_lookup_table}"),
            (self.deterministic, "det=1"),
            (self.dbias, "dbias=1"),
            (
                self.dropout_variant and self.dropout_variant != "no",
                f"drv={self.dropout_variant}",
            ),
        )
        parts.extend(token for enabled, token in optional if enabled)
        # Always include block_per_cu for uniform naming
        parts.append(f"bpc={self.block_per_cu}")
        return "_".join(parts)

    def to_codegen_json(self) -> str:
        """Serialise the configuration as a JSON string for the code generator.

        Top-level keys are ``arch``, ``signature`` and ``algorithm`` (in that
        order). ``fp8_static_quant``, ``max_splits_log2`` and ``max_seq_len_q``
        are emitted as constants; ``store_randval`` is True when
        ``dropout_variant`` contains ``"storerandval"``.
        """
        signature = {
            "family": self.family,
            "data_type": self.data_type,
            "mode": self.mode,
            "vlayout": self.vlayout,
            "hdim_q": self.hdim_q,
            "hdim_v": self.hdim_v,
            "mask": self.mask,
            "bias": self.bias,
            "lse": self.lse,
            "dropout": self.dropout,
            "qscale": self.qscale,
            "rope": self.rope,
            "logits": self.logits,
            "paged_kv": self.paged_kv,
            "fp8_static_quant": False,
            "skip_min_seqlen_q": self.skip_min_seqlen_q,
            "sink": self.sink,
            "dbias": self.dbias,
            "store_randval": "storerandval" in self.dropout_variant,
            "deterministic": self.deterministic,
            "dropout_variant": self.dropout_variant,
            "kv_memory_layout": self.kv_memory_layout,
            "kv_lookup_table": self.kv_lookup_table,
            "page_size": self.page_size,
        }
        algorithm = {
            "pipeline": self.pipeline,
            "tile": list(self.tile),
            "wave": list(self.wave),
            "warp": list(self.warp),
            "padding": list(self.padding),
            "block_per_cu": self.block_per_cu,
            "num_wave_groups": self.num_wave_groups,
            "max_splits_log2": 0,
            "max_seq_len_q": 0,
            "use_trload": self.use_trload,
        }
        return json.dumps(
            {"arch": self.gfx_arch, "signature": signature, "algorithm": algorithm}
        )
|
|
|
|
|
|
# =============================================================================
|
|
# CPU reference
|
|
# =============================================================================
|
|
|
|
|
|
def _float32_to_bf16(arr: np.ndarray) -> np.ndarray:
|
|
"""Convert float32 array to bf16 stored as uint16 (truncate lower 16 bits)."""
|
|
return arr.astype(np.float32).view(np.uint32).__rshift__(16).astype(np.uint16)
|
|
|
|
|
|
def _bf16_to_float32(arr: np.ndarray) -> np.ndarray:
|
|
"""Convert bf16 (uint16) array back to float32."""
|
|
return (arr.astype(np.uint32) << 16).view(np.float32)
|
|
|
|
|
|
def cpu_attention_fwd(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    scale: float,
    mask_type: int = 0,
) -> np.ndarray:
    """CPU reference: scaled dot-product attention (supports GQA and causal mask).

    Args:
        Q: [batch, nhead_q, seqlen_q, hdim_q] float32
        K: [batch, nhead_k, seqlen_k, hdim_q] float32
        V: [batch, nhead_k, seqlen_k, hdim_v] float32
        scale: multiplier applied to Q @ K^T before the softmax
        mask_type: 0=no mask, 1=causal top-left, 2=causal bottom-right

    Returns:
        O: [batch, nhead_q, seqlen_q, hdim_v] float32

    Raises:
        ValueError: if nhead_q is not a multiple of nhead_k.

    A query row whose keys are all masked out (bottom-right causal with
    seqlen_q > seqlen_k) yields an all-zero output row instead of an average
    of V. Masked logits use -inf, not a large finite negative. This is chosen
    to match CK Tile's host reference softmax, which zeroes fully-masked rows.
    """
    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    if nhead_q != nhead_k:
        # GQA/MQA: each KV head serves a contiguous group of query heads.
        # np.repeat produces [k0, k0, ..., k1, k1, ...] which matches that.
        if nhead_k == 0 or nhead_q % nhead_k != 0:
            raise ValueError(
                f"nhead_q ({nhead_q}) must be a multiple of nhead_k ({nhead_k})"
            )
        ratio = nhead_q // nhead_k
        K = np.repeat(K, ratio, axis=1)
        V = np.repeat(V, ratio, axis=1)
    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    if mask_type in (1, 2):
        sq, sk = S.shape[-2], S.shape[-1]
        row = np.arange(sq).reshape(sq, 1)
        col = np.arange(sk).reshape(1, sk)
        if mask_type == 1:  # top-left causal
            causal_mask = col <= row
        else:  # bottom-right causal
            causal_mask = col <= (row + sk - sq)
        # -inf (not e.g. -1e9, which overflows fp16 and gives fully-masked rows
        # a uniform distribution) so masked logits contribute exactly zero.
        S = np.where(causal_mask, S, -np.inf)
    S_max = S.max(axis=-1, keepdims=True)
    # A fully-masked row has max -inf; shift it by 0 so exp() gives 0, not NaN.
    S_max = np.where(np.isneginf(S_max), 0, S_max)
    S_exp = np.exp(S - S_max)
    denom = S_exp.sum(axis=-1, keepdims=True)
    # Fully-masked rows have denom 0: keep their probabilities at 0.
    P = S_exp / np.where(denom == 0, 1, denom)
    return np.matmul(P, V)
|
|
|
|
|
|
def cpu_attention_fwd_with_intermediates(
    Q: np.ndarray, K: np.ndarray, V: np.ndarray, scale: float
) -> tuple:
    """CPU reference forward returning (output, P) for backward use.

    Unmasked scaled dot-product attention that also returns the softmax
    probability matrix P, shaped [batch, nhead_q, seqlen_q, seqlen_k] (one
    slice per query head, also under GQA).

    Raises:
        ValueError: if nhead_q is not a multiple of nhead_k. Previously,
            nhead_q < nhead_k silently produced an empty output.
    """
    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    if nhead_q != nhead_k:
        if nhead_k == 0 or nhead_q % nhead_k != 0:
            raise ValueError(
                f"nhead_q ({nhead_q}) must be a multiple of nhead_k ({nhead_k})"
            )
        # GQA: expand each KV head to its group of query heads.
        ratio = nhead_q // nhead_k
        K = np.repeat(K, ratio, axis=1)
        V = np.repeat(V, ratio, axis=1)
    S = np.matmul(Q, K.transpose(0, 1, 3, 2)) * scale
    S_max = S.max(axis=-1, keepdims=True)  # subtract row max for stability
    S_exp = np.exp(S - S_max)
    P = S_exp / S_exp.sum(axis=-1, keepdims=True)
    out = np.matmul(P, V)
    return out, P
|
|
|
|
|
|
def cpu_attention_bwd(
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    dO: np.ndarray,
    P: np.ndarray,
    scale: float,
) -> tuple:
    """CPU reference backward. Returns (dQ, dK, dV).

    Args:
        Q: forward query input [batch, nhead_q, seqlen_q, hdim_q]
        K, V: forward key/value inputs [batch, nhead_k, seqlen_k, dim]
        out: forward output [batch, nhead_q, seqlen_q, hdim_v]
        dO: gradient of output, same shape as ``out``
        P: softmax probabilities from forward [batch, nhead_q, seqlen_q, seqlen_k]
        scale: attention scale factor used in the forward pass

    Returns:
        (dQ, dK, dV) with the shapes of Q, K and V respectively. With GQA
        (nhead_q a multiple of nhead_k), the gradients of the query heads
        sharing a KV head are summed into that head's dK/dV.

    Raises:
        ValueError: if nhead_q is not a multiple of nhead_k.
    """
    nhead_q = Q.shape[1]
    nhead_k = K.shape[1]
    group = 1
    K_rep, V_rep = K, V
    if nhead_q != nhead_k:
        if nhead_k == 0 or nhead_q % nhead_k != 0:
            raise ValueError(
                f"nhead_q ({nhead_q}) must be a multiple of nhead_k ({nhead_k})"
            )
        # Expand KV heads exactly like the forward reference (np.repeat), so
        # query head h uses KV head h // group.
        group = nhead_q // nhead_k
        K_rep = np.repeat(K, group, axis=1)
        V_rep = np.repeat(V, group, axis=1)

    # D_i = sum_j dO_ij * O_ij (row-wise), the softmax-backward correction term.
    D = (dO * out).sum(axis=-1, keepdims=True)
    dP = np.matmul(dO, V_rep.transpose(0, 1, 3, 2))
    dS = P * (dP - D)
    dQ = np.matmul(dS, K_rep) * scale
    dK = np.matmul(dS.transpose(0, 1, 3, 2), Q) * scale
    dV = np.matmul(P.transpose(0, 1, 3, 2), dO)

    if group > 1:
        # Fold per-query-head gradients back onto their shared KV head.
        dK = dK.reshape(dK.shape[0], nhead_k, group, *dK.shape[2:]).sum(axis=2)
        dV = dV.reshape(dV.shape[0], nhead_k, group, *dV.shape[2:]).sum(axis=2)
    return dQ, dK, dV
|
|
|
|
|
|
# =============================================================================
|
|
# Low-level ctypes wrapper
|
|
# =============================================================================
|
|
|
|
|
|
class FmhaDispatcherLib:
    """Wrapper for the FMHA dispatcher shared library (libdispatcher_fmha_lib.so).

    Declares the ctypes signatures of the library's C entry points and offers
    thin helpers for initialisation, backward dispatch, kernel counting and
    cleanup. The forward / split-KV / paged-KV / append-KV / batch-prefill
    entry points are invoked directly through ``self._lib`` by FmhaRunner.run.
    """

    # Candidate locations relative to get_dispatcher_root(), probed in order.
    SEARCH_PATHS = [
        "build/examples/libdispatcher_fmha_lib.so",
        "build/libdispatcher_fmha_lib.so",
        "build/lib/libdispatcher_fmha_lib.so",
    ]

    def __init__(self, lib: ctypes.CDLL, path: Path):
        """Wrap an already-loaded library and declare its C signatures.

        Raises:
            AttributeError: if *lib* does not export one of the expected
                entry points (ctypes raises this on symbol lookup).
        """
        self._lib = lib
        self.path = path
        self._setup()

    def _setup(self):
        """Declare argtypes/restype for every exported entry point.

        ctypes cannot verify these against the C declarations, so the order
        below must match the library exactly. Position comments name each
        argument as the Python callers in this module pass it.
        """
        lib = self._lib
        ptr = ctypes.c_void_p
        i32 = ctypes.c_int
        f32 = ctypes.c_float
        cstr = ctypes.c_char_p
        time_out = ctypes.POINTER(ctypes.c_float)  # time_ms_out
        # batch, nhead_q, nhead_k, seqlen_q, seqlen_k, hdim_q, hdim_v
        problem_dims = [i32] * 7

        lib.fmha_dispatcher_initialize.argtypes = [cstr]  # arch name
        lib.fmha_dispatcher_initialize.restype = i32

        lib.fmha_dispatcher_run_fwd.argtypes = [
            ptr,  # q
            ptr,  # k
            ptr,  # v
            ptr,  # o
            *problem_dims,
            f32,  # scale
            i32,  # mask_type
            i32,  # bias_type
            i32,  # has_lse
            i32,  # has_dropout
            i32,  # traits_hdim_q (0=same as hdim_q)
            i32,  # traits_hdim_v (0=same as hdim_v)
            i32,  # is_v_rowmajor (1=row, 0=col)
            i32,  # perm (1=BHSD, 0=BSHD)
            cstr,  # data_type ("fp16", "bf16")
            i32,  # is_group_mode
            i32,  # window_left (-1=no window)
            i32,  # window_right (-1=no window, 0=causal)
            i32,  # has_logits
            i32,  # has_sink
            i32,  # has_skip
            time_out,
        ]
        lib.fmha_dispatcher_run_fwd.restype = i32

        lib.fmha_dispatcher_run_bwd.argtypes = [
            ptr,  # q
            ptr,  # k
            ptr,  # v
            ptr,  # o
            ptr,  # lse
            ptr,  # do
            ptr,  # dq
            ptr,  # dk
            ptr,  # dv
            *problem_dims,
            f32,  # scale
            cstr,  # data_type_str
            i32,  # mask_type_int
            i32,  # bias_type_int
            i32,  # has_dropout
            i32,  # has_dbias
            i32,  # is_deterministic
            i32,  # is_group_mode
            i32,  # is_store_randval
            i32,  # tile_n0 (kN0 for nsplits computation)
            time_out,
        ]
        lib.fmha_dispatcher_run_bwd.restype = i32

        # Split-KV forward
        lib.fmha_dispatcher_run_splitkv.argtypes = [
            ptr,  # q
            ptr,  # k
            ptr,  # v
            ptr,  # o
            *problem_dims,
            f32,  # scale
            i32,  # mask_type
            i32,  # num_splits
            i32,  # is_v_rowmajor
            cstr,  # data_type
            i32,  # has_lse
            i32,  # is_group_mode
            i32,  # perm
            i32,  # has_logits
            i32,  # bias_type
            i32,  # has_sink
            i32,  # paged_kv
            i32,  # page_block_size
            i32,  # window_left
            i32,  # window_right
            time_out,
        ]
        lib.fmha_dispatcher_run_splitkv.restype = i32

        # Paged-KV forward
        lib.fmha_dispatcher_run_pagedkv.argtypes = [
            ptr,  # q
            ptr,  # k
            ptr,  # v
            ptr,  # o
            *problem_dims,
            f32,  # scale
            i32,  # mask_type
            i32,  # page_block_size
            i32,  # is_v_rowmajor
            cstr,  # data_type
            i32,  # has_lse
            i32,  # has_logits
            i32,  # has_sink
            i32,  # skip_min_seqlen_q
            i32,  # bias_type
            time_out,
        ]
        lib.fmha_dispatcher_run_pagedkv.restype = i32

        # Append-KV
        lib.fmha_dispatcher_run_appendkv.argtypes = [
            ptr,  # q
            ptr,  # k
            ptr,  # v
            # batch, nhead_q, nhead_k, seqlen_q, seqlen_knew, hdim_q, hdim_v
            *problem_dims,
            i32,  # is_v_rowmajor
            i32,  # rope_type
            i32,  # paged_kv
            i32,  # page_block_size
            cstr,  # data_type
            time_out,
        ]
        lib.fmha_dispatcher_run_appendkv.restype = i32

        # Batch Prefill
        lib.fmha_dispatcher_run_batch_prefill.argtypes = [
            ptr,  # q
            ptr,  # k
            ptr,  # v
            ptr,  # o
            *problem_dims,
            f32,  # scale
            i32,  # mask_type
            i32,  # bias_type
            i32,  # page_block_size
            i32,  # kv_layout_int
            i32,  # kv_lookup_int
            i32,  # is_v_rowmajor
            cstr,  # data_type
            i32,  # has_lse
            i32,  # has_dropout
            i32,  # has_logits
            i32,  # has_sink
            i32,  # skip_min_seqlen_q
            time_out,
        ]
        lib.fmha_dispatcher_run_batch_prefill.restype = i32

        lib.fmha_dispatcher_kernel_count.argtypes = []
        lib.fmha_dispatcher_kernel_count.restype = i32
        lib.fmha_dispatcher_cleanup.argtypes = []
        lib.fmha_dispatcher_cleanup.restype = None

    @classmethod
    def find(cls) -> Optional["FmhaDispatcherLib"]:
        """Load the first usable library from SEARCH_PATHS, or return None.

        Candidates that fail to load (OSError) or that lack an expected entry
        point (AttributeError raised while declaring signatures) are skipped.
        A stale build earlier in the list therefore no longer aborts the
        search before a newer build further down is tried.
        """
        root = get_dispatcher_root()
        for rel in cls.SEARCH_PATHS:
            path = root / rel
            if not path.exists():
                continue
            try:
                return cls(ctypes.CDLL(str(path)), path)
            except (OSError, AttributeError):
                continue
        return None

    @classmethod
    def load(cls, path: str) -> "FmhaDispatcherLib":
        """Load the library at *path*; load/symbol errors propagate to the caller."""
        lib = ctypes.CDLL(path)
        return cls(lib, Path(path))

    def initialize(self, arch: str = "gfx950") -> bool:
        """Initialise the dispatcher for *arch*; True when the C call returns 0."""
        return self._lib.fmha_dispatcher_initialize(arch.encode()) == 0

    def run_bwd(
        self,
        q: ctypes.c_void_p,
        k: ctypes.c_void_p,
        v: ctypes.c_void_p,
        o: ctypes.c_void_p,
        lse: ctypes.c_void_p,
        do_grad: ctypes.c_void_p,
        dq: ctypes.c_void_p,
        dk: ctypes.c_void_p,
        dv: ctypes.c_void_p,
        prob: FmhaProblem,
        data_type: str = "fp16",
        mask_type: int = 0,
        bias_type: int = 0,
        has_dropout: bool = False,
        has_dbias: bool = False,
        is_deterministic: bool = False,
        is_group_mode: bool = False,
        is_store_randval: bool = False,
        tile_n0: int = 128,
    ) -> Tuple[int, float]:
        """Invoke fmha_dispatcher_run_bwd on device buffers.

        Returns:
            (rc, time_ms): the C return code and the time it reported.
        """
        time_ms = ctypes.c_float(0.0)
        rc = self._lib.fmha_dispatcher_run_bwd(
            q,
            k,
            v,
            o,
            lse,
            do_grad,
            dq,
            dk,
            dv,
            prob.batch,
            prob.nhead_q,
            prob.nhead_k,
            prob.seqlen_q,
            prob.seqlen_k,
            prob.hdim_q,
            prob.hdim_v,
            prob.scale,
            data_type.encode(),
            ctypes.c_int(mask_type),
            ctypes.c_int(bias_type),
            ctypes.c_int(int(has_dropout)),
            ctypes.c_int(int(has_dbias)),
            ctypes.c_int(int(is_deterministic)),
            ctypes.c_int(int(is_group_mode)),
            ctypes.c_int(int(is_store_randval)),
            ctypes.c_int(tile_n0),
            ctypes.byref(time_ms),
        )
        return rc, time_ms.value

    def kernel_count(self) -> int:
        """Return the value of fmha_dispatcher_kernel_count()."""
        return self._lib.fmha_dispatcher_kernel_count()

    def cleanup(self):
        """Call fmha_dispatcher_cleanup()."""
        self._lib.fmha_dispatcher_cleanup()
|
|
|
|
|
|
# =============================================================================
|
|
# High-level GPU runner (mirrors GpuGroupedConvRunner)
|
|
# =============================================================================
|
|
|
|
|
|
class FmhaRunner:
|
|
"""High-level FMHA runner with NumPy interface and HIP memory management."""
|
|
|
|
HIP_MEMCPY_H2D = 1
|
|
HIP_MEMCPY_D2H = 2
|
|
|
|
def __init__(self, dispatch_lib: FmhaDispatcherLib, arch: str = "gfx950"):
|
|
self._lib = dispatch_lib
|
|
self._arch = arch
|
|
self._hip = None
|
|
self._load_hip()
|
|
if not dispatch_lib.initialize(arch):
|
|
raise RuntimeError("Failed to initialize FMHA dispatcher")
|
|
|
|
def _load_hip(self):
|
|
for name in ["libamdhip64.so", "libamdhip64.so.6"]:
|
|
try:
|
|
self._hip = ctypes.CDLL(name)
|
|
self._hip.hipMalloc.argtypes = [
|
|
ctypes.POINTER(ctypes.c_void_p),
|
|
ctypes.c_size_t,
|
|
]
|
|
self._hip.hipMalloc.restype = ctypes.c_int
|
|
self._hip.hipFree.argtypes = [ctypes.c_void_p]
|
|
self._hip.hipFree.restype = ctypes.c_int
|
|
self._hip.hipMemcpy.argtypes = [
|
|
ctypes.c_void_p,
|
|
ctypes.c_void_p,
|
|
ctypes.c_size_t,
|
|
ctypes.c_int,
|
|
]
|
|
self._hip.hipMemcpy.restype = ctypes.c_int
|
|
self._hip.hipMemset.argtypes = [
|
|
ctypes.c_void_p,
|
|
ctypes.c_int,
|
|
ctypes.c_size_t,
|
|
]
|
|
self._hip.hipMemset.restype = ctypes.c_int
|
|
return
|
|
except OSError:
|
|
continue
|
|
raise RuntimeError("Could not load libamdhip64.so")
|
|
|
|
@classmethod
|
|
def from_prebuilt(cls, arch: Optional[str] = None) -> "FmhaRunner":
|
|
arch = arch or detect_gpu_arch()
|
|
lib = FmhaDispatcherLib.find()
|
|
if lib is None:
|
|
raise RuntimeError(
|
|
"FMHA dispatcher library not found. Build with:\n"
|
|
" cd dispatcher/build && cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON && make dispatcher_fmha_lib"
|
|
)
|
|
return cls(lib, arch)
|
|
|
|
@classmethod
|
|
def from_library(cls, path: str, arch: Optional[str] = None) -> "FmhaRunner":
|
|
arch = arch or detect_gpu_arch()
|
|
return cls(FmhaDispatcherLib.load(path), arch)
|
|
|
|
def run(
|
|
self,
|
|
Q: np.ndarray,
|
|
K: np.ndarray,
|
|
V: np.ndarray,
|
|
prob: FmhaProblem,
|
|
mask_type: int = 0,
|
|
bias_type: int = 0,
|
|
has_lse: int = 0,
|
|
has_dropout: int = 0,
|
|
has_logits: int = 0,
|
|
has_sink: int = 0,
|
|
has_skip: int = 0,
|
|
api_family: str = "fwd",
|
|
data_type: str = "fp16",
|
|
**kwargs,
|
|
) -> "FmhaResult":
|
|
"""Run FMHA forward on GPU with automatic HIP memory management.
|
|
|
|
Args:
|
|
Q: [batch, nhead_q, seqlen_q, hdim_q] float16
|
|
K: [batch, nhead_k, seqlen_k, hdim_q] float16
|
|
V: [batch, nhead_k, seqlen_k, hdim_v] float16
|
|
|
|
Returns:
|
|
FmhaResult with output array, timing, TFLOPS
|
|
"""
|
|
# Map CK dtype to numpy dtype for buffer allocation.
|
|
# bf16 is stored as uint16 (upper 16 bits of float32).
|
|
# fp8 uses uint8 (1 byte per element).
|
|
_NP_DTYPE = {
|
|
"fp16": np.float16,
|
|
"bf16": np.uint16,
|
|
"fp32": np.float32,
|
|
"fp8bf16": np.uint8,
|
|
"fp8fp32": np.uint8,
|
|
"bf8": np.uint8,
|
|
}
|
|
_NP_OUT_DTYPE = {
|
|
"fp16": np.float16,
|
|
"bf16": np.uint16,
|
|
"fp32": np.float32,
|
|
"fp8bf16": np.float16,
|
|
"fp8fp32": np.float32,
|
|
"bf8": np.uint8,
|
|
}
|
|
in_dt = _NP_DTYPE.get(data_type, np.float16)
|
|
out_dt = _NP_OUT_DTYPE.get(data_type, np.float16)
|
|
if data_type == "bf16":
|
|
Q_c = _float32_to_bf16(np.ascontiguousarray(Q.astype(np.float32)))
|
|
K_c = _float32_to_bf16(np.ascontiguousarray(K.astype(np.float32)))
|
|
V_c = _float32_to_bf16(np.ascontiguousarray(V.astype(np.float32)))
|
|
else:
|
|
Q_c = np.ascontiguousarray(Q.astype(in_dt))
|
|
K_c = np.ascontiguousarray(K.astype(in_dt))
|
|
V_c = np.ascontiguousarray(V.astype(in_dt))
|
|
O_c = np.zeros(prob.o_shape(), dtype=out_dt)
|
|
|
|
d_q, d_k, d_v, d_o = (ctypes.c_void_p() for _ in range(4))
|
|
|
|
try:
|
|
self._hip.hipMalloc(ctypes.byref(d_q), Q_c.nbytes)
|
|
self._hip.hipMalloc(ctypes.byref(d_k), K_c.nbytes)
|
|
self._hip.hipMalloc(ctypes.byref(d_v), V_c.nbytes)
|
|
self._hip.hipMalloc(ctypes.byref(d_o), O_c.nbytes)
|
|
|
|
self._hip.hipMemcpy(d_q, Q_c.ctypes.data, Q_c.nbytes, self.HIP_MEMCPY_H2D)
|
|
self._hip.hipMemcpy(d_k, K_c.ctypes.data, K_c.nbytes, self.HIP_MEMCPY_H2D)
|
|
self._hip.hipMemcpy(d_v, V_c.ctypes.data, V_c.nbytes, self.HIP_MEMCPY_H2D)
|
|
self._hip.hipMemset(d_o, 0, O_c.nbytes)
|
|
|
|
time_ms = ctypes.c_float(0.0)
|
|
lib = self._lib._lib
|
|
|
|
is_v_rowmajor = kwargs.get("is_v_rowmajor", 1)
|
|
is_group_mode = kwargs.get("is_group_mode", 0)
|
|
perm = kwargs.get("perm", 1)
|
|
window_left = kwargs.get("window_left", -1)
|
|
window_right = kwargs.get("window_right", -1)
|
|
num_splits = kwargs.get("num_splits", 4)
|
|
page_size = kwargs.get("page_size", 64)
|
|
kv_layout = kwargs.get("kv_layout", 0)
|
|
kv_lookup = kwargs.get("kv_lookup", 0)
|
|
traits_hdim_q = kwargs.get("traits_hdim_q", 0)
|
|
traits_hdim_v = kwargs.get("traits_hdim_v", 0)
|
|
|
|
if api_family == "splitkv":
|
|
paged_kv = kwargs.get("paged_kv", 0)
|
|
rc = lib.fmha_dispatcher_run_splitkv(
|
|
d_q,
|
|
d_k,
|
|
d_v,
|
|
d_o,
|
|
prob.batch,
|
|
prob.nhead_q,
|
|
prob.nhead_k,
|
|
prob.seqlen_q,
|
|
prob.seqlen_k,
|
|
prob.hdim_q,
|
|
prob.hdim_v,
|
|
prob.scale,
|
|
mask_type,
|
|
num_splits,
|
|
is_v_rowmajor,
|
|
data_type.encode(),
|
|
has_lse,
|
|
is_group_mode,
|
|
perm,
|
|
has_logits,
|
|
bias_type,
|
|
has_sink,
|
|
paged_kv,
|
|
page_size,
|
|
window_left,
|
|
window_right,
|
|
ctypes.byref(time_ms),
|
|
)
|
|
elif api_family == "pagedkv":
|
|
rc = lib.fmha_dispatcher_run_pagedkv(
|
|
d_q,
|
|
d_k,
|
|
d_v,
|
|
d_o,
|
|
prob.batch,
|
|
prob.nhead_q,
|
|
prob.nhead_k,
|
|
prob.seqlen_q,
|
|
prob.seqlen_k,
|
|
prob.hdim_q,
|
|
prob.hdim_v,
|
|
prob.scale,
|
|
mask_type,
|
|
page_size,
|
|
is_v_rowmajor,
|
|
data_type.encode(),
|
|
has_lse,
|
|
has_logits,
|
|
has_sink,
|
|
has_skip,
|
|
bias_type,
|
|
ctypes.byref(time_ms),
|
|
)
|
|
elif api_family == "appendkv":
|
|
seqlen_knew = kwargs.get("seqlen_knew", prob.seqlen_k)
|
|
rc = lib.fmha_dispatcher_run_appendkv(
|
|
Q_c.ctypes.data,
|
|
K_c.ctypes.data,
|
|
V_c.ctypes.data,
|
|
prob.batch,
|
|
prob.nhead_q,
|
|
prob.nhead_k,
|
|
prob.seqlen_q,
|
|
seqlen_knew,
|
|
prob.hdim_q,
|
|
prob.hdim_v,
|
|
is_v_rowmajor,
|
|
kwargs.get("rope_type", 0),
|
|
kwargs.get("paged_kv", 0),
|
|
page_size,
|
|
data_type.encode(),
|
|
ctypes.byref(time_ms),
|
|
)
|
|
elif api_family == "batch_prefill":
|
|
skip_min_sq = kwargs.get("skip_min_seqlen_q", 0)
|
|
rc = lib.fmha_dispatcher_run_batch_prefill(
|
|
d_q,
|
|
d_k,
|
|
d_v,
|
|
d_o,
|
|
prob.batch,
|
|
prob.nhead_q,
|
|
prob.nhead_k,
|
|
prob.seqlen_q,
|
|
prob.seqlen_k,
|
|
prob.hdim_q,
|
|
prob.hdim_v,
|
|
prob.scale,
|
|
mask_type,
|
|
bias_type,
|
|
page_size,
|
|
kv_layout,
|
|
kv_lookup,
|
|
is_v_rowmajor,
|
|
data_type.encode(),
|
|
has_lse,
|
|
has_dropout,
|
|
has_logits,
|
|
has_sink,
|
|
skip_min_sq,
|
|
ctypes.byref(time_ms),
|
|
)
|
|
else:
|
|
rc = lib.fmha_dispatcher_run_fwd(
|
|
d_q,
|
|
d_k,
|
|
d_v,
|
|
d_o,
|
|
prob.batch,
|
|
prob.nhead_q,
|
|
prob.nhead_k,
|
|
prob.seqlen_q,
|
|
prob.seqlen_k,
|
|
prob.hdim_q,
|
|
prob.hdim_v,
|
|
prob.scale,
|
|
mask_type,
|
|
bias_type,
|
|
has_lse,
|
|
has_dropout,
|
|
traits_hdim_q,
|
|
traits_hdim_v,
|
|
is_v_rowmajor,
|
|
perm,
|
|
data_type.encode(),
|
|
is_group_mode,
|
|
window_left,
|
|
window_right,
|
|
has_logits,
|
|
has_sink,
|
|
has_skip,
|
|
ctypes.byref(time_ms),
|
|
)
|
|
|
|
if rc != 0:
|
|
return FmhaResult(success=False, error=f"Kernel failed (rc={rc})")
|
|
|
|
self._hip.hipMemcpy(O_c.ctypes.data, d_o, O_c.nbytes, self.HIP_MEMCPY_D2H)
|
|
|
|
# Convert bf16 output (uint16) back to float32 for comparison
|
|
if data_type == "bf16":
|
|
O_c = _bf16_to_float32(O_c)
|
|
|
|
# appendkv is a memory op (KV cache copy), not compute -- no TFLOPS
|
|
ops = 0 if api_family == "appendkv" else prob.num_ops
|
|
tflops = (
|
|
ops / (time_ms.value * 1e-3) / 1e12
|
|
if time_ms.value > 0 and ops > 0
|
|
else 0.0
|
|
)
|
|
return FmhaResult(
|
|
success=True, output=O_c, time_ms=time_ms.value, tflops=tflops
|
|
)
|
|
|
|
finally:
|
|
for d in [d_q, d_k, d_v, d_o]:
|
|
if d.value:
|
|
self._hip.hipFree(d)
|
|
|
|
def run_bwd(
    self,
    Q: np.ndarray,
    K: np.ndarray,
    V: np.ndarray,
    out: np.ndarray,
    LSE: np.ndarray,
    dO: np.ndarray,
    prob: FmhaProblem,
    data_type: str = "fp16",
    mask_type: int = 0,
    bias_type: int = 0,
    has_dropout: bool = False,
    has_dbias: bool = False,
    is_deterministic: bool = False,
    is_group_mode: bool = False,
    is_store_randval: bool = False,
    tile_n0: int = 128,
) -> "FmhaResult":
    """Run FMHA backward on GPU with automatic HIP memory management.

    Q, K, V, ``out`` (the forward output) and dO are uploaded in the element
    type selected by ``data_type``; LSE is always uploaded as float32. The dQ,
    dK and dV device buffers are zero-initialised before the kernel call, and
    all device buffers are freed on every exit path.

    For ``data_type == "bf16"`` float inputs are rounded (round-to-nearest-even)
    to bfloat16 and uploaded as raw uint16 bit patterns; inputs that are
    already uint16 are treated as bf16 bits and uploaded unchanged. The
    gradients are widened back to float32 before being returned, matching how
    the forward path returns bf16 output.

    Returns:
        FmhaResult whose ``output`` is the tuple ``(dQ, dK, dV)`` on success,
        or ``success=False`` with an error message when the kernel call
        returns a non-zero status.
    """
    # Host element type per data_type; bf16 is handled separately (uint16 bits).
    # NOTE(review): the fp8 entries upload via a plain astype(np.uint8), which
    # truncates float values rather than encoding fp8 -- presumably callers
    # pass pre-quantised uint8 data; confirm.
    _NP_DTYPE = {
        "fp16": np.float16,
        "fp32": np.float32,
        "fp8bf16": np.uint8,
        "fp8fp32": np.uint8,
        "bf8": np.uint8,
    }
    is_bf16 = data_type == "bf16"
    in_dt = _NP_DTYPE.get(data_type, np.float16)

    def _to_bf16_bits(arr: np.ndarray) -> np.ndarray:
        # float32 -> bfloat16 with round-to-nearest-even, returned as uint16 bits.
        f32 = np.ascontiguousarray(arr, dtype=np.float32)
        bits = f32.view(np.uint32)
        rounding = ((bits >> 16) & 1) + np.uint32(0x7FFF)
        rounded = ((bits + rounding) >> 16).astype(np.uint16)
        # The rounding add can carry a NaN payload into inf/0; force quiet NaN.
        return np.where(np.isnan(f32), np.uint16(0x7FC0), rounded).astype(np.uint16)

    def _bf16_bits_to_f32(bits: np.ndarray) -> np.ndarray:
        # bf16 is the upper half of an IEEE float32, so widening is a shift.
        return (bits.astype(np.uint32) << 16).view(np.float32)

    def _to_device_dtype(arr: np.ndarray) -> np.ndarray:
        if not is_bf16:
            return np.ascontiguousarray(arr.astype(in_dt))
        if arr.dtype == np.uint16:
            return np.ascontiguousarray(arr)  # already raw bf16 bit patterns
        return _to_bf16_bits(arr)

    Q_c = _to_device_dtype(Q)
    K_c = _to_device_dtype(K)
    V_c = _to_device_dtype(V)
    O_c = _to_device_dtype(out)
    LSE_c = np.ascontiguousarray(LSE.astype(np.float32))
    dO_c = _to_device_dtype(dO)
    dQ_c = np.zeros_like(Q_c)
    dK_c = np.zeros_like(K_c)
    dV_c = np.zeros_like(V_c)

    ptrs = [ctypes.c_void_p() for _ in range(9)]
    d_q, d_k, d_v, d_o, d_lse, d_do, d_dq, d_dk, d_dv = ptrs

    try:
        # Upload inputs.
        for d, arr in zip(ptrs[:6], [Q_c, K_c, V_c, O_c, LSE_c, dO_c]):
            self._hip.hipMalloc(ctypes.byref(d), arr.nbytes)
            self._hip.hipMemcpy(d, arr.ctypes.data, arr.nbytes, self.HIP_MEMCPY_H2D)
        # Gradient outputs start zeroed on the device.
        for d, arr in zip(ptrs[6:], [dQ_c, dK_c, dV_c]):
            self._hip.hipMalloc(ctypes.byref(d), arr.nbytes)
            self._hip.hipMemset(d, 0, arr.nbytes)

        rc, elapsed = self._lib.run_bwd(
            d_q,
            d_k,
            d_v,
            d_o,
            d_lse,
            d_do,
            d_dq,
            d_dk,
            d_dv,
            prob,
            data_type,
            mask_type=mask_type,
            bias_type=bias_type,
            has_dropout=has_dropout,
            has_dbias=has_dbias,
            is_deterministic=is_deterministic,
            is_group_mode=is_group_mode,
            is_store_randval=is_store_randval,
            tile_n0=tile_n0,
        )

        if rc != 0:
            return FmhaResult(success=False, error=f"BWD kernel failed (rc={rc})")

        for d, arr in zip(ptrs[6:], [dQ_c, dK_c, dV_c]):
            self._hip.hipMemcpy(arr.ctypes.data, d, arr.nbytes, self.HIP_MEMCPY_D2H)

        grads = (dQ_c, dK_c, dV_c)
        if is_bf16:
            # Hand back float32 like the forward path does for bf16 output.
            grads = tuple(_bf16_bits_to_f32(g) for g in grads)

        tflops = prob.num_ops / (elapsed * 1e-3) / 1e12 if elapsed > 0 else 0.0
        return FmhaResult(
            success=True,
            output=grads,
            time_ms=elapsed,
            tflops=tflops,
        )
    finally:
        for d in ptrs:
            if d.value:
                self._hip.hipFree(d)
|
|
|
|
@property
def kernel_count(self) -> int:
    """Number of kernels reported by the loaded dispatcher library."""
    lib = self._lib
    return lib.kernel_count()

@property
def library_path(self) -> str:
    """Path of the loaded dispatcher shared library, as a string."""
    path = self._lib.path
    return str(path)

def cleanup(self):
    """Forward to the underlying library wrapper's ``cleanup()``."""
    lib = self._lib
    lib.cleanup()
|
|
|
|
|
|
# =============================================================================
|
|
# JIT Build Support (mirrors setup_multiple_gemm_dispatchers)
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class FmhaSetupResult:
    """Outcome of JIT-building (or reusing a cached) FMHA dispatcher library.

    Fields:
        success: True when the library was produced (and, where applicable,
            loaded).
        config: the kernel configuration this result refers to.
        runner: loaded runner; only set by build paths that load the library.
        library_path: path to the built ``.so`` (empty when none).
        error: human-readable failure reason (empty on success).
        build_time_s: wall-clock seconds spent, when the caller measured it.
    """

    success: bool
    config: Optional[FmhaKernelConfig] = None
    runner: Optional[FmhaRunner] = None
    library_path: str = ""
    error: str = ""
    build_time_s: float = 0.0
|
|
|
|
|
|
def _build_static_lib(root: Path) -> Optional[Path]:
    """Configure and build ``libck_tile_dispatcher.a`` under ``root / "build"``.

    The build step uses ``cmake --build`` instead of calling ``make`` directly,
    so it works with whichever generator CMake was configured with
    (Makefiles, Ninja, ...).

    Args:
        root: dispatcher source root containing the top-level CMakeLists.txt.

    Returns:
        Path to the built archive, looked up in both ``build/`` and
        ``build/lib/`` (the same locations ``_find_static_lib`` searches).
        Returns None, after printing a warning to stderr, if configuring or
        building fails, if cmake is not installed, or if no archive is found.
    """
    build_dir = root / "build"
    build_dir.mkdir(parents=True, exist_ok=True)
    hipcc = _find_hipcc()
    steps = [
        ("cmake", ["cmake", str(root), f"-DCMAKE_CXX_COMPILER={hipcc}"]),
        (
            "cmake --build",
            [
                "cmake",
                "--build",
                ".",
                "--target",
                "ck_tile_dispatcher",
                "--parallel",
                str(os.cpu_count() or 4),
            ],
        ),
    ]
    for label, cmd in steps:
        try:
            r = subprocess.run(cmd, cwd=str(build_dir), capture_output=True, text=True)
        except FileNotFoundError as e:
            # Tool not installed: degrade to "not found" instead of crashing.
            print(f"Warning: {label} failed for dispatcher lib: {e}", file=sys.stderr)
            return None
        if r.returncode != 0:
            # The tail of stderr usually carries the actual error message.
            print(
                f"Warning: {label} failed for dispatcher lib: {r.stderr[-200:]}",
                file=sys.stderr,
            )
            return None

    for candidate in (
        build_dir / "libck_tile_dispatcher.a",
        build_dir / "lib" / "libck_tile_dispatcher.a",
    ):
        if candidate.exists():
            return candidate
    return None
|
|
|
|
|
|
def _find_static_lib() -> Optional[Path]:
    """Locate ``libck_tile_dispatcher.a`` in the dispatcher build tree.

    Looks in ``build/`` and then ``build/lib/`` under the dispatcher root. If
    neither exists, it announces on stderr and delegates to
    ``_build_static_lib``, returning that result (None on failure).
    """
    dispatcher_root = get_dispatcher_root()
    candidates = (
        dispatcher_root / "build" / "libck_tile_dispatcher.a",
        dispatcher_root / "build" / "lib" / "libck_tile_dispatcher.a",
    )
    found = next((c for c in candidates if c.exists()), None)
    if found is not None:
        return found

    # Nothing prebuilt -- build it now.
    print(" Building libck_tile_dispatcher.a (first time)...", file=sys.stderr)
    return _build_static_lib(dispatcher_root)
|
|
|
|
|
|
def _find_hipcc() -> str:
    """Return the hipcc compiler command to use.

    Prefers ``/opt/rocm/bin/hipcc``, then ``/usr/bin/hipcc``. If neither
    exists, it returns the bare name ``"hipcc"`` so the executable is resolved
    via PATH when it is run.
    """
    candidates = ("/opt/rocm/bin/hipcc", "/usr/bin/hipcc")
    return next((c for c in candidates if os.path.exists(c)), "hipcc")
|
|
|
|
|
|
def fmha_compile_flags(arch: str, hipcc: str = "", family: str = "") -> List[str]:
    """Return the base hipcc command line (no sources/outputs) for FMHA kernels.

    Shared by the JIT path and the tile engine. Modelled on
    example/ck_tile/01_fmha/CMakeLists.txt so dispatcher builds use the same
    flags as CK's own FMHA examples.

    Args:
        arch: offload target, e.g. ``"gfx950"``. ``gfx9*`` (CDNA) targets get
            the MFMA define set; every other target only disables fast exp2.
        hipcc: compiler used as argv[0]; auto-detected when empty.
        family: kernel family. Any ``"bwd*"`` family adds
            ``CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3`` (BWD bf16 conversion mode).

    Defines emitted:
        - CK_TILE_FMHA_FWD_FAST_EXP2: 1 on gfx9 (CDNA), 0 elsewhere
        - CK_TILE_USE_OCP_FP8: OCP standard fp8 format (gfx9 only)
        - CK_GFX950_SUPPORT / CK_USE_GFX950 / CK_USE_GFX94 (gfx9 only)
        - CK_USE_XDL: MFMA (matrix fused multiply-add) instructions (gfx9 only)
        - CK_TILE_USE_WMMA=0: CDNA uses MFMA rather than WMMA (gfx9 only)
        - CK_TILE_FMHA_FWD_{SPLITKV,APPENDKV,PAGEDKV}_API=1: always
    """
    # NOTE(review): every gfx9* target (including gfx90a/gfx942) receives the
    # gfx950/gfx94 and OCP-fp8 defines. Confirm this matches CK's per-target
    # CMake logic before relying on fp8 kernels on non-gfx950 parts.
    compiler = hipcc or _find_hipcc()
    dispatcher_root = get_dispatcher_root()
    ck_root = dispatcher_root.parent

    command = [
        compiler,
        "-c",
        "-fPIC",
        "-O3",
        "-DNDEBUG",
        f"--offload-arch={arch}",
        "-std=c++17",
        f"-I{ck_root / 'include'}",
        f"-I{dispatcher_root / 'include'}",
        f"-I{ck_root}",
        "-Wno-undefined-func-template",
        "-Wno-float-equal",
        "-fgpu-flush-denormals-to-zero",
        "-fno-offload-uniform-block",
    ]
    # Backend options, each passed through with its own -mllvm prefix.
    for llvm_opt in (
        "--lsr-drop-solution=1",
        "-enable-post-misched=0",
        "-amdgpu-early-inline-all=true",
        "-amdgpu-function-calls=false",
    ):
        command += ["-mllvm", llvm_opt]

    if arch.startswith("gfx9"):
        arch_defines = [
            "-DCK_TILE_FMHA_FWD_FAST_EXP2=1",
            "-DCK_TILE_USE_OCP_FP8",
            "-DCK_GFX950_SUPPORT",
            "-DCK_USE_GFX950",
            "-DCK_USE_GFX94",
            "-DCK_USE_XDL",
            "-DCK_TILE_USE_WMMA=0",
        ]
    else:
        arch_defines = ["-DCK_TILE_FMHA_FWD_FAST_EXP2=0"]

    # API enablement (mirrors the CMakeLists.txt conditional defines).
    api_defines = [
        "-DCK_TILE_FMHA_FWD_SPLITKV_API=1",
        "-DCK_TILE_FMHA_FWD_APPENDKV_API=1",
        "-DCK_TILE_FMHA_FWD_PAGEDKV_API=1",
    ]
    family_defines = (
        ["-DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3"] if family.startswith("bwd") else []
    )
    return command + arch_defines + api_defines + family_defines
|
|
|
|
|
|
def _make_splitkv_combine_config(splitkv_cfg: FmhaKernelConfig) -> FmhaKernelConfig:
    """Derive the fwd_splitkv_combine config that pairs with a fwd_splitkv config.

    Source: fmha_fwd.py splitkv_combine tile -- fixed (32, hdim_v, 32, 32) tile;
    combine_bn1=32 comes from specs.py load_arch_specs() splitkv_combine dict.

    The combine kernel merges the split stage's partial results into the final
    output, so it has to live in the same .so as the split kernel for the
    two-stage split-KV pipeline. The input config is not modified: a shallow
    copy is taken and the overrides below are applied to it.

    Mask, bias and other fields not overridden are inherited from the split
    config on purpose. The combine kernel ignores them, but the dispatcher's
    supports() check matches the combine kernel's signature against the
    problem traits, so they must agree with the split kernel.
    """
    import copy

    hdim_v = splitkv_cfg.hdim_v
    combine = copy.copy(splitkv_cfg)
    # NOTE(review): rope is set to "none" here while _make_bwd_dot_do_o_config
    # uses "no"; confirm which spelling the codegen expects.
    overrides = {
        "family": "fwd_splitkv_combine",
        "pipeline": "splitkv_combine",
        "hdim_q": hdim_v,
        "hdim_v": hdim_v,
        "tile_m0": 32,
        "tile_n0": hdim_v,
        "tile_k0": 32,
        "tile_n1": 32,
        "tile_k1": 0,
        "tile_k0max": 0,
        # Only group mode pads the seqlen_q dimension.
        "pad_s": 1 if splitkv_cfg.mode == "group" else 0,
        "pad_sk": 1,
        "pad_d": 1,
        "pad_dv": 1,
        "lse": True,
        "dropout": False,
        "skip_min_seqlen_q": False,
        "qscale": "no",
        "rope": "none",
    }
    for attr, value in overrides.items():
        setattr(combine, attr, value)
    return combine
|
|
|
|
|
|
def _make_bwd_dot_do_o_config(dq_cfg: FmhaKernelConfig) -> FmhaKernelConfig:
    """Derive the bwd_dot_do_o config that pairs with a bwd_dq_dk_dv config.

    Source: fmha_bwd.py FmhaBwdDotDoOTileSize -- fixed tile (64, max(hv,128), 32),
    with a (4, 1, 1) wave layout and a (32, 32, 16) warp tile for both stages
    (the standard fp16/bf16 MFMA configuration).

    The dot_do_o kernel computes d = rowsum(O * dO). It must be built into the
    same .so as the dq_dk_dv kernel for the multi-stage BWD plan. The input
    config is not modified: a shallow copy is taken and overridden.
    """
    import copy

    hq, hv = dq_cfg.hdim_q, dq_cfg.hdim_v
    wide_v = max(hv, 128)
    dot = copy.copy(dq_cfg)

    settings = {
        "family": "bwd_dot_do_o",
        "pipeline": "qr",
        "tile_m0": 64,
        "tile_n0": wide_v,
        "tile_k0": 32,
        "tile_n1": wide_v,
        "tile_k1": 32,
        "tile_k0max": max(hq, 128),
    }
    for stage in ("0", "1"):
        settings.update({f"wave_m{stage}": 4, f"wave_n{stage}": 1, f"wave_k{stage}": 1})
    for stage in ("0", "1"):
        settings.update({f"warp_m{stage}": 32, f"warp_n{stage}": 32, f"warp_k{stage}": 16})
    settings.update(
        {
            "use_trload": False,
            # All dimensions padded for maximum problem-size compatibility.
            "pad_s": 1,
            "pad_sk": 1,
            "pad_d": 1,
            "pad_dv": 1,
            # BWD traits carry no logits/sink/skip/lse/paged_kv, so
            # from_invocation defaults them to false/0; the signature must match.
            "logits": False,
            "sink": False,
            "skip_min_seqlen_q": False,
            "lse": False,
            "paged_kv": False,
            "qscale": "no",
            "rope": "no",
        }
    )
    for attr, value in settings.items():
        setattr(dot, attr, value)
    # Dropout settings are deliberately inherited so is_store_randval in the
    # signature matches the problem traits.
    return dot
|
|
|
|
|
|
def setup_fmha_dispatcher(
    config: FmhaKernelConfig,
    output_dir: Optional[Path] = None,
    verbose: bool = False,
) -> FmhaSetupResult:
    """JIT-compile a single FMHA kernel and return a runner.

    Cached: if the .so already exists and loads, it is returned directly,
    without codegen, compilation, or locating/building the static dispatcher
    library.
    Fresh build: codegen -> parallel compile (kernel(s) + ctypes) -> link -> load.

    Companion kernels required by multi-stage plans are generated into the
    same library, with the same pairing ``setup_multiple_fmha_dispatchers``
    uses:
    - ``bwd_dq_dk_dv`` gets its ``bwd_dot_do_o`` kernel;
    - ``fwd_splitkv`` gets its ``fwd_splitkv_combine`` kernel.

    Args:
        config: kernel configuration to build.
        output_dir: build directory; defaults to
            ``<dispatcher>/build/examples/fmha_jit_<config.name>``.
        verbose: currently unused; kept for API compatibility.

    Returns:
        FmhaSetupResult with ``runner`` set on success, or ``success=False``
        and ``error`` naming the failing step.
    """
    import time

    t0 = time.perf_counter()

    root = get_dispatcher_root()
    codegen_dir = root / "codegen"
    ctypes_src = root / "bindings" / "ctypes" / "fmha_ctypes_lib.cpp"
    hipcc = _find_hipcc()

    if output_dir is None:
        output_dir = root / "build" / "examples" / f"fmha_jit_{config.name}"
    output_dir.mkdir(parents=True, exist_ok=True)

    lib_name = f"libdispatcher_fmha_{config.name}.so"
    lib_path = output_dir / lib_name

    # Cache hit: .so already exists, just load it.
    if lib_path.exists():
        try:
            runner = FmhaRunner.from_library(str(lib_path), config.gfx_arch)
            return FmhaSetupResult(
                success=True,
                config=config,
                runner=runner,
                library_path=str(lib_path),
                build_time_s=time.perf_counter() - t0,
            )
        except Exception:
            pass  # stale / unloadable .so -- fall through and rebuild

    # Located only now: a cache hit does not need it, and _find_static_lib()
    # may run a full cmake build when the archive is missing.
    static_lib = _find_static_lib()
    if not static_lib:
        return FmhaSetupResult(
            success=False, config=config, error="libck_tile_dispatcher.a not found"
        )
    if not ctypes_src.exists():
        return FmhaSetupResult(
            success=False, config=config, error="fmha_ctypes_lib.cpp not found"
        )

    # Step 1: Codegen. Multi-stage families need their companion kernel in
    # the same .so (dot_do_o for BWD, combine for split-KV).
    if config.family == "bwd_dq_dk_dv":
        bundle = [_make_bwd_dot_do_o_config(config), config]
    elif config.family == "fwd_splitkv":
        bundle = [config, _make_splitkv_combine_config(config)]
    else:
        bundle = []
    if bundle:
        config_json_str = json.dumps([json.loads(c.to_codegen_json()) for c in bundle])
    else:
        config_json_str = config.to_codegen_json()

    gen_cmd = [
        sys.executable,
        str(codegen_dir / "fmha" / "generate_fallback.py"),
        "--output-dir",
        str(output_dir),
        "--gpu-target",
        config.gfx_arch,
        "--config-json",
        config_json_str,
    ]
    r = subprocess.run(gen_cmd, capture_output=True, text=True, cwd=str(codegen_dir))
    if r.returncode != 0:
        return FmhaSetupResult(
            success=False, config=config, error=f"Codegen failed: {r.stderr[:500]}"
        )

    dispatch_header = output_dir / "fmha_python_dispatch.hpp"
    if not dispatch_header.exists():
        return FmhaSetupResult(
            success=False, config=config, error="Dispatch header not generated"
        )

    # Step 2: Compile kernel .cpp files AND the ctypes shim in parallel.
    kernel_cpps = list(output_dir.glob("fmha_*.cpp"))
    base_flags = fmha_compile_flags(config.gfx_arch, hipcc, family=config.family)

    compile_jobs = []
    for cpp in kernel_cpps:
        obj = cpp.with_suffix(".o")
        compile_jobs.append((base_flags + [str(cpp), "-o", str(obj)], obj, "kernel"))

    ctypes_obj = output_dir / "fmha_ctypes_lib.o"
    ctypes_cmd = base_flags + [
        f"-I{output_dir}",
        f"-I{output_dir / 'dispatcher_wrappers'}",
        f"-include{dispatch_header}",
        f'-DGFX_ARCH="{config.gfx_arch}"',
        str(ctypes_src),
        "-o",
        str(ctypes_obj),
    ]
    compile_jobs.append((ctypes_cmd, ctypes_obj, "ctypes"))

    def _run_compile(job):
        # Existing objects are reused as-is.
        cmd, obj, label = job
        if obj.exists():
            return (True, obj, label, "")
        r = subprocess.run(cmd, capture_output=True, text=True)
        return (r.returncode == 0, obj, label, r.stderr[:500])

    with ThreadPoolExecutor(max_workers=len(compile_jobs)) as pool:
        results = list(pool.map(_run_compile, compile_jobs))

    kernel_objs = []
    for ok, obj, label, err in results:
        if not ok:
            return FmhaSetupResult(
                success=False,
                config=config,
                error=f"{label} compile failed: {err}",
            )
        if label == "kernel":
            kernel_objs.append(str(obj))

    # Step 3: Link.
    link_cmd = [
        hipcc,
        "-shared",
        "-fPIC",
        str(ctypes_obj),
        *kernel_objs,
        str(static_lib),
        "-o",
        str(lib_path),
    ]
    r = subprocess.run(link_cmd, capture_output=True, text=True)
    if r.returncode != 0:
        return FmhaSetupResult(
            success=False, config=config, error=f"Link failed: {r.stderr[:500]}"
        )

    # Step 4: Load.
    try:
        runner = FmhaRunner.from_library(str(lib_path), config.gfx_arch)
    except Exception as e:
        return FmhaSetupResult(success=False, config=config, error=f"Load failed: {e}")

    elapsed = time.perf_counter() - t0
    return FmhaSetupResult(
        success=True,
        config=config,
        runner=runner,
        library_path=str(lib_path),
        build_time_s=elapsed,
    )
|
|
|
|
|
|
def _run_compile_job(job):
    """Run one compile command.

    The function is defined at module level so a ProcessPoolExecutor can
    pickle it. It uses no threads; compiler stderr goes to a file.

    Args:
        job: ``(cmd, obj_path, kernel_name, label)``; ``label`` is not used here.

    Returns:
        ``(kernel_name, ok, err)``.
        - If ``obj_path`` already exists, the job counts as success and is not
          recompiled.
        - Compiler stderr is written to ``<obj_path>.err``. On success that file
          is deleted.
        - On failure the file is kept for inspection and its first 200
          characters are returned as ``err`` (``"rc=<code>"`` if it cannot be
          read).
        - If the compiler cannot be launched at all (e.g. missing binary), the
          job fails with the OS error text instead of raising out of the
          worker.
    """
    cmd, obj_str, name, _label = job
    if os.path.exists(obj_str):
        return (name, True, "")
    err_path = obj_str + ".err"
    try:
        with open(err_path, "w") as ef:
            rc = subprocess.call(cmd, stdout=subprocess.DEVNULL, stderr=ef)
    except OSError as e:
        return (name, False, str(e)[:200] or "failed to launch compiler")
    if rc != 0:
        try:
            # errors="replace": compiler output is not guaranteed to be UTF-8.
            with open(err_path, errors="replace") as ef:
                err = ef.read(200)
        except (OSError, ValueError):
            err = f"rc={rc}"
        return (name, False, err)
    try:
        os.unlink(err_path)
    except OSError:
        pass
    return (name, True, "")
|
|
|
|
|
|
def setup_multiple_fmha_dispatchers(
    configs: List[FmhaKernelConfig],
    output_dir: Optional[Path] = None,
    verbose: bool = False,
    max_workers: Optional[int] = None,
    executor=None,
    progress_callback=None,
) -> List[FmhaSetupResult]:
    """Build many FMHA dispatcher libraries in three stages.

    Stages:
    1. Codegen, run sequentially.
    2. Compile. Every hipcc job (kernel + ctypes) from ALL configs shares one
       process pool (or ``executor`` if given), which is why this is faster
       than calling setup_fmha_dispatcher() per kernel.
    3. Link, run sequentially.

    Libraries are NOT loaded here: results carry ``library_path`` and the
    caller creates runners.

    Caching:
    - a config whose ``.so`` exists is reported as success without rebuilding;
    - existing codegen output and ``.o`` files are reused;
    - a previous codegen failure (``_codegen_err.txt`` but no dispatch header)
      is reported again without re-running codegen.

    Compile flags and the ``GFX_ARCH`` define come from each config's own
    ``gfx_arch``. The static dispatcher library is only located (and possibly
    built) when at least one config actually needs linking.

    Args:
        configs: kernel configurations to build.
        output_dir: parent directory; each config builds in
            ``<output_dir>/fmha_jit_<name>`` (default ``<dispatcher>/build/examples``).
        verbose: currently unused; kept for API compatibility.
        max_workers: process count when no ``executor`` is supplied.
        executor: optional caller-owned executor (not shut down here).
        progress_callback: called as ``(stage, done, total)`` with stage
            ``"codegen"`` or ``"compile"``.

    Returns:
        One FmhaSetupResult per input config, in input order.
    """
    if not configs:
        return []

    root = get_dispatcher_root()
    codegen_dir = root / "codegen"
    ctypes_src = root / "bindings" / "ctypes" / "fmha_ctypes_lib.cpp"
    hipcc = _find_hipcc()

    if output_dir is None:
        output_dir = root / "build" / "examples"

    results: dict[str, FmhaSetupResult] = {}

    # --- Stage 1: Codegen (sequential, skip cached) ---
    def _codegen(cfg):
        out = output_dir / f"fmha_jit_{cfg.name}"
        lib_path = out / f"libdispatcher_fmha_{cfg.name}.so"
        dispatch_hpp = out / "fmha_python_dispatch.hpp"
        err_file = out / "_codegen_err.txt"
        # Fast path: .so exists, register result and skip.
        if lib_path.exists():
            results[cfg.name] = FmhaSetupResult(
                success=True, config=cfg, library_path=str(lib_path)
            )
            return (cfg.name, cfg, out, True)
        # Fast path: a previous codegen failed (error log but no header).
        if out.exists() and not dispatch_hpp.exists() and err_file.exists():
            results[cfg.name] = FmhaSetupResult(
                success=False, config=cfg, error="Codegen failed (cached)"
            )
            return (cfg.name, cfg, out, False)
        out.mkdir(parents=True, exist_ok=True)
        # Codegen already done (has .hpp but no .so yet).
        if dispatch_hpp.exists():
            return (cfg.name, cfg, out, True)
        # Multi-stage families need their companion kernel in the same .so.
        if cfg.family == "bwd_dq_dk_dv":
            dot = _make_bwd_dot_do_o_config(cfg)
            config_json_str = json.dumps(
                [json.loads(dot.to_codegen_json()), json.loads(cfg.to_codegen_json())]
            )
        elif cfg.family == "fwd_splitkv":
            comb = _make_splitkv_combine_config(cfg)
            config_json_str = json.dumps(
                [json.loads(cfg.to_codegen_json()), json.loads(comb.to_codegen_json())]
            )
        else:
            config_json_str = cfg.to_codegen_json()
        with open(err_file, "w") as ef:
            rc = subprocess.call(
                [
                    sys.executable,
                    str(codegen_dir / "fmha" / "generate_fallback.py"),
                    "--output-dir",
                    str(out),
                    "--gpu-target",
                    cfg.gfx_arch,
                    "--config-json",
                    config_json_str,
                ],
                stdout=subprocess.DEVNULL,
                stderr=ef,
                cwd=str(codegen_dir),
            )
        ok = rc == 0 and dispatch_hpp.exists()
        if not ok:
            err_msg = err_file.read_text()[:200] if err_file.exists() else "unknown"
            results[cfg.name] = FmhaSetupResult(
                success=False, config=cfg, error=f"Codegen failed: {err_msg}"
            )
        return (cfg.name, cfg, out, ok)

    codegen_results = []
    for i, cfg in enumerate(configs):
        codegen_results.append(_codegen(cfg))
        if progress_callback:
            progress_callback("codegen", i + 1, len(configs))

    # Configs that still need compiling and linking.
    config_dirs: dict[str, tuple[FmhaKernelConfig, Path]] = {}
    for name, cfg, out, ok in codegen_results:
        if ok and name not in results:
            config_dirs[name] = (cfg, out)

    # Build prerequisites are only looked up when something must be built;
    # _find_static_lib() may trigger a cmake build.
    static_lib = None
    if config_dirs:
        missing = ""
        if not ctypes_src.exists():
            missing = "fmha_ctypes_lib.cpp not found"
        else:
            static_lib = _find_static_lib()
            if not static_lib:
                missing = "libck_tile_dispatcher.a not found"
        if missing:
            for name, (cfg, _) in config_dirs.items():
                results[name] = FmhaSetupResult(success=False, config=cfg, error=missing)
            config_dirs = {}

    # --- Stage 2: Collect ALL compile jobs, run in one pool ---
    # The "bwd" family yields the superset of defines (incl. BWD-specific
    # ones), so one flag set per arch serves every family.
    flags_by_arch = {}

    def _flags_for(arch):
        if arch not in flags_by_arch:
            flags_by_arch[arch] = fmha_compile_flags(arch, hipcc, family="bwd")
        return flags_by_arch[arch]

    compile_jobs = []  # (cmd, obj_path, kernel_name, label)
    for name, (cfg, out) in config_dirs.items():
        base_flags = _flags_for(cfg.gfx_arch)
        for cpp in out.glob("fmha_*.cpp"):
            obj = cpp.with_suffix(".o")
            if not obj.exists():
                compile_jobs.append(
                    (base_flags + [str(cpp), "-o", str(obj)], str(obj), name, "kernel")
                )
        ctypes_obj = out / "fmha_ctypes_lib.o"
        if not ctypes_obj.exists():
            dispatch = out / "fmha_python_dispatch.hpp"
            compile_jobs.append(
                (
                    base_flags
                    + [
                        f"-I{out}",
                        f"-I{out / 'dispatcher_wrappers'}",
                        f"-include{dispatch}",
                        f'-DGFX_ARCH="{cfg.gfx_arch}"',
                        str(ctypes_src),
                        "-o",
                        str(ctypes_obj),
                    ],
                    str(ctypes_obj),
                    name,
                    "ctypes",
                )
            )

    failed_names: set = set()

    if compile_jobs:
        own_pool = None
        pool = executor
        if pool is None:
            workers = max_workers or min(len(compile_jobs), os.cpu_count() or 4)
            own_pool = ProcessPoolExecutor(max_workers=workers)
            pool = own_pool
        try:
            total_jobs = len(compile_jobs)
            for done_count, (name, ok, err) in enumerate(
                pool.map(_run_compile_job, compile_jobs), start=1
            ):
                if progress_callback:
                    progress_callback("compile", done_count, total_jobs)
                if not ok:
                    failed_names.add(name)
                    if name not in results:
                        cfg, _ = config_dirs[name]
                        results[name] = FmhaSetupResult(
                            success=False, config=cfg, error=f"Compile: {err}"
                        )
        finally:
            # Only shut down a pool we created; a caller's executor stays alive.
            if own_pool is not None:
                own_pool.shutdown(wait=True)

    # --- Stage 3: Link (no GPU access -- runner loading deferred to caller) ---
    for name, (cfg, out) in config_dirs.items():
        if name in failed_names or name in results:
            continue
        lib_path = out / f"libdispatcher_fmha_{name}.so"
        if not lib_path.exists():
            objs = sorted(out.glob("*.o"))
            r = subprocess.run(
                [
                    hipcc,
                    "-shared",
                    "-fPIC",
                    *[str(o) for o in objs],
                    str(static_lib),
                    "-o",
                    str(lib_path),
                ],
                capture_output=True,
                text=True,
            )
            if r.returncode != 0:
                results[name] = FmhaSetupResult(
                    success=False, config=cfg, error=f"Link: {r.stderr[:200]}"
                )
                continue
        results[name] = FmhaSetupResult(
            success=True, config=cfg, library_path=str(lib_path)
        )

    # Return in original order.
    return [
        results.get(c.name, FmhaSetupResult(success=False, config=c, error="skipped"))
        for c in configs
    ]
|
|
|
|
|
|
# =============================================================================
|
|
# Registry (mirrors ctypes_utils.Registry)
|
|
# =============================================================================
|
|
|
|
|
|
class FmhaRegistry:
    """Collects FMHA kernel configs and builds them together.

    ``build()`` hands every registered config to
    ``setup_multiple_fmha_dispatchers`` so that all compiles share one pool.
    """

    def __init__(self, name: str = "fmha"):
        self._name = name  # registry label; not read anywhere in this class
        self._kernels: List[FmhaKernelConfig] = []

    def register_kernel(self, config: FmhaKernelConfig):
        """Queue ``config`` for the next ``build()`` (duplicates are kept)."""
        self._kernels += [config]

    def __len__(self):
        kernels = self._kernels
        return len(kernels)

    def build(
        self,
        verbose: bool = False,
        max_workers: Optional[int] = None,
    ) -> List[FmhaSetupResult]:
        """Build all registered kernels; one result per kernel, in registration order."""
        kernels = self._kernels
        return setup_multiple_fmha_dispatchers(
            kernels, verbose=verbose, max_workers=max_workers
        )
|
|
|
|
|
|
# =============================================================================
|
|
# Validator (mirrors ctypes_utils.Validator)
|
|
# =============================================================================
|
|
|
|
|
|
class FmhaValidator:
    """Validates FMHA GPU output against a reference.

    Usage:
        validator = FmhaValidator(rtol=1e-2, atol=1e-2)
        ok, max_abs, max_rel = validator.check(gpu_output, cpu_reference)
    """

    def __init__(self, rtol: float = 1e-2, atol: float = 1e-2):
        self.rtol = rtol
        self.atol = atol

    def check(
        self, output: np.ndarray, reference: np.ndarray
    ) -> Tuple[bool, float, float]:
        """Compare ``output`` with ``reference`` element-wise in float32.

        The check passes when ``np.allclose(output, reference, atol, rtol)``
        holds, i.e. ``|out - ref| <= atol + rtol * |ref|`` for every element.
        The reported relative error divides by ``|ref| + 1e-6`` to avoid
        division by zero. Empty inputs compare as valid with zero error
        (``max()`` would otherwise raise on an empty array).

        Returns:
            (is_valid, max_abs_error, max_rel_error)
        """
        out_f32 = np.asarray(output, dtype=np.float32)
        ref_f32 = np.asarray(reference, dtype=np.float32)
        diff = np.abs(out_f32 - ref_f32)
        if diff.size == 0:
            return True, 0.0, 0.0
        max_abs = float(diff.max())
        max_rel = float((diff / (np.abs(ref_f32) + 1e-6)).max())
        ok = bool(np.allclose(out_f32, ref_f32, atol=self.atol, rtol=self.rtol))
        return ok, max_abs, max_rel
|
|
|
|
|
|
# =============================================================================
|
|
# KernelSpec + spec_to_config (mirrors ctypes_utils.KernelSpec)
|
|
# =============================================================================
|
|
|
|
|
|
@dataclass
class FmhaKernelSpec:
    """Compact FMHA kernel description, expanded by ``spec_to_config()``.

    Mirrors GEMM's KernelSpec: a name plus the few dimensions users usually
    vary. ``spec_to_config`` uses ``hdim`` for both the Q/K and V head
    dimensions.
    """

    name: str
    hdim: int = 128
    pipeline: str = "qr_async"
    # First-GEMM (Q @ K^T) block tile: M, N and K extents.
    tile_m0: int = 128
    tile_n0: int = 128
    tile_k0: int = 32
|
|
|
|
|
|
def spec_to_config(
    spec: FmhaKernelSpec, dtype: str = "fp16", arch: str = "gfx950"
) -> FmhaKernelConfig:
    """Expand an FmhaKernelSpec into a full FmhaKernelConfig.

    - ``spec.hdim`` sets hdim_q, hdim_v, tile_n1 and tile_k0max.
    - tile_k1 reuses ``spec.tile_k0``.
    - Every other FmhaKernelConfig field keeps its default.
    """
    head_dim = spec.hdim
    tile_fields = dict(
        tile_m0=spec.tile_m0,
        tile_n0=spec.tile_n0,
        tile_k0=spec.tile_k0,
        tile_n1=head_dim,
        tile_k1=spec.tile_k0,
        tile_k0max=head_dim,
    )
    return FmhaKernelConfig(
        data_type=dtype,
        hdim_q=head_dim,
        hdim_v=head_dim,
        pipeline=spec.pipeline,
        gfx_arch=arch,
        **tile_fields,
    )
|
|
|
|
|
|
# =============================================================================
|
|
# Split-K heuristic (from fmhaarch.md Section 9.5)
|
|
# =============================================================================
|