[rocm-libraries] ROCm/rocm-libraries#5260 (commit a1834d2)

[CK] [CK_Tile] Add FMHA scaffolding to CK kernel dispatcher
 (#5260)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The CK Tile dispatcher currently supports GEMM and Grouped Convolution
but has no support for Fused Multi-Head Attention (FMHA). The
example/ck_tile/01_fmha folder contains a comprehensive FMHA
implementation with forward, backward, split-KV, paged-KV, append-KV,
and batch-prefill kernels across multiple GPU architectures — but there
is no unified dispatch layer for it. This PR ports the FMHA stack into
the dispatcher, following the same architectural patterns established by
GEMM and Grouped Convolution, enabling runtime kernel selection, JIT
compilation from Python, and a declarative C++ example flow. Autotuning
heuristics to follow.

## Technical Details

This PR adds FMHA scaffolding to the CK dispatcher framework, mirroring
GEMM's layered architecture. Seven new C++ runtime headers provide type
definitions (coexisting with upstream headers via __has_include,
requiring zero modifications to example/ck_tile/01_fmha/), a problem
builder with 18+ setters, Signature + Algorithm kernel key matching, a
virtual kernel instance, a DECL_FMHA_KERNEL_SET macro with wildcard
support and named tile/wave/warp setters, arch-aware registry with JSON
export, and a dispatcher with seqtune-aware selection, configurable
timing, and multi-stage execution plans for split-KV (two-stage) and
backward (three-stage). The codegen pipeline is driven by a
fmha_arch_specs.json capturing per-arch tile tables and pipeline
constraints for five architectures (gfx90a/942/950/1100/1201), migrated
from hardcoded logic in 01_fmha/codegen/, with supporting modules for
C++ symbol mappings, validation rules, and named receipt profiles
(ck_default, flash, pytorch, aiter, fp32, fp8). Python integration
(fmha_utils.py) mirrors the C++ layer with JIT compilation, parallel
multi-kernel builds, HIP memory management via ctypes, tolerance-based
validation, and a NumPy CPU reference with GQA support. Twenty-seven C++
and thirty-two Python examples cover the full feature surface — forward,
split-KV, masks, bias, dropout, GQA, backward, append-KV, batch prefill,
fp8, logits soft cap, sink tokens, and parameter sweeps — all
JIT-compiled on the fly.

## Test Plan

Seven test files cover the runtime types, codegen, and end-to-end
correctness. C++ unit tests validate the problem builder, dispatcher
planning (single-stage for forward/paged-KV/append-KV; multi-stage for
split-KV and backward), registry operations, and the kernel-set
declaration macro. Python unit tests verify codegen emission, profile
filtering, and 15 validation rules for masks, hdim constraints, and
pipeline requirements. GPU execution validation in 01_basic_fmha
--validate reports zero errors across 65,536 elements with max absolute
error of 7.29e-05. A gold-standard parity suite (test_fmha_parity.py)
runs 14 configurations through both the upstream tile_example_fmha_fwd
and the dispatcher, comparing exit codes to confirm behavioral parity —
all 14 match.

## Test Result

The C++ smoke test builds and passes all 9 compiled examples, and a
Python JIT sweep (29_sweep_seqlen.py) passes 7/7 configurations reaching
up to 375 TFLOPS at seqlen 2048.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Vidyasagar Ananthan
2026-05-17 07:30:33 +00:00
committed by assistant-librarian[bot]
parent 61b019f2a2
commit 86591de476
148 changed files with 41250 additions and 87 deletions

View File

@@ -0,0 +1,371 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
//
// Example 01: Basic FMHA Forward with GPU Execution
//
// Demonstrates the full flow:
// 1. Declare kernels via DECL_FMHA_KERNEL_SET
// 2. Register and plan
// 3. Allocate Q, K, V, O GPU buffers
// 4. Run the FMHA forward kernel on GPU
// 5. Copy output to host and validate against CPU reference
//
// Mirrors 01_basic_gemm.cpp for FMHA.
#include <hip/hip_runtime.h>
#include <cmath>
#include <iomanip>
#include <iostream>
#include <random>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
// FMHA tile/wave/warp dimensions correspond to TWO GEMM stages:
// Stage 0 (Q * K^T): tile_m0 x tile_n0 x tile_k0 (seqlen_q x seqlen_k x hdim_q)
// Stage 1 (Attn * V): tile_m0 x tile_n1 x tile_k1 (seqlen_q x hdim_v x seqlen_k)
// Wave/warp follow the same stage pattern: *_m0/n0/k0 for stage 0, *_m1/n1/k1 for stage 1.
DECL_FMHA_KERNEL_SET(basic_fmha_kernels,
.add(FmhaSignature()
.family("fwd")
.dtype("fp16")
.mode("batch")
.vlayout("r") // V row-major
.hdim(128) // hdim_q = hdim_v = 128
.mask("no")
.bias("no")
.lse(false)
.dropout(false)
.qscale("no"),
FmhaAlgorithm()
// Stage 0 tile: seqlen_q=128, seqlen_k=128, hdim_q=32
.tile_m0(128)
.tile_n0(128)
.tile_k0(32)
// Stage 1 tile: hdim_v=128, seqlen_k=32, alignment=128
.tile_n1(128)
.tile_k1(32)
.tile_k0max(128)
// Wave: 4 warps on m, 1 on n, 1 on k (both stages)
.wave_m0(4)
.wave_n0(1)
.wave_k0(1)
.wave_m1(4)
.wave_n1(1)
.wave_k1(1)
// Warp tile: 32x32x16 (both stages)
.warp_m0(32)
.warp_n0(32)
.warp_k0(16)
.warp_m1(32)
.warp_n1(32)
.warp_k1(16)
.pipeline("qr_async")
.padding(true, true, true, true) // pad_s, pad_sk, pad_d, pad_dv
.alignments(128, 128) // hdim_q_alignment, hdim_v_alignment
.selection_rank(0),
"gfx950"));
namespace {
using FmhaDataType = ck_tile::fp16_t;
void cpu_attention_fwd(const std::vector<float>& Q,
const std::vector<float>& K,
const std::vector<float>& V,
std::vector<float>& O,
int batch,
int nhead,
int seqlen_q,
int seqlen_k,
int hdim_q,
int hdim_v,
float scale)
{
for(int b = 0; b < batch; ++b)
{
for(int h = 0; h < nhead; ++h)
{
for(int sq = 0; sq < seqlen_q; ++sq)
{
std::vector<float> scores(seqlen_k, 0.0f);
float max_score = -1e30f;
for(int sk = 0; sk < seqlen_k; ++sk)
{
float dot = 0.0f;
for(int d = 0; d < hdim_q; ++d)
{
int q_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_q + d;
int k_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_q + d;
dot += Q[q_idx] * K[k_idx];
}
scores[sk] = dot * scale;
max_score = std::max(max_score, scores[sk]);
}
float sum_exp = 0.0f;
for(int sk = 0; sk < seqlen_k; ++sk)
{
scores[sk] = std::exp(scores[sk] - max_score);
sum_exp += scores[sk];
}
for(int sk = 0; sk < seqlen_k; ++sk)
{
scores[sk] /= sum_exp;
}
for(int dv = 0; dv < hdim_v; ++dv)
{
float acc = 0.0f;
for(int sk = 0; sk < seqlen_k; ++sk)
{
int v_idx = ((b * nhead + h) * seqlen_k + sk) * hdim_v + dv;
acc += scores[sk] * V[v_idx];
}
int o_idx = ((b * nhead + h) * seqlen_q + sq) * hdim_v + dv;
O[o_idx] = acc;
}
}
}
}
}
} // namespace
int main(int argc, char* argv[])
{
ExampleArgs args("Example 01: FMHA Forward (GPU Execution)", "FMHA with real GPU data");
args.add_option("--arch", "gfx950", "GPU architecture");
args.add_option("--batch", "2", "Batch size");
args.add_option("--nhead", "4", "Number of heads");
args.add_option("--seqlen", "64", "Sequence length (Q and K)");
args.add_option("--hdim", "128", "Head dimension");
args.add_flag("--validate", "Validate against CPU reference");
if(!args.parse(argc, argv))
return 0;
const std::string gfx_arch = args.get("--arch", "gfx950");
const int batch = args.get_int("--batch", 2);
const int nhead = args.get_int("--nhead", 4);
const int seqlen = args.get_int("--seqlen", 64);
const int hdim = args.get_int("--hdim", 128);
print_header("Example 01: FMHA Forward (GPU Execution)");
// Step 1: Register kernels
std::cout << "\nStep 1: Register Kernels\n";
FmhaKernelSetRegistry::instance().print();
FmhaRegistry registry;
registry.set_name("basic_fmha");
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s)\n";
FmhaDispatcher dispatcher(&registry);
dispatcher.set_benchmarking(true);
dispatcher.set_timing(1, 3);
// Step 2: Plan
const float scale = 1.0f / std::sqrt(static_cast<float>(hdim));
fmha_fwd_traits traits{};
traits.hdim_q = hdim;
traits.hdim_v = hdim;
traits.data_type = "fp16";
traits.is_group_mode = false;
traits.is_v_rowmajor = true;
traits.has_logits_soft_cap = false;
traits.mask_type = mask_enum::no_mask;
traits.bias_type = bias_enum::no_bias;
traits.has_lse = false;
traits.has_dropout = false;
traits.qscale_type = quant_scale_enum::no_scale;
const int64_t q_elems = static_cast<int64_t>(batch) * nhead * seqlen * hdim;
const int64_t k_elems = q_elems;
const int64_t v_elems = q_elems;
const int64_t o_elems = q_elems;
// Step 3: Allocate GPU buffers
std::cout << "\nStep 2: Allocate GPU Buffers\n";
std::cout << " Q/K/V/O: [" << batch << ", " << nhead << ", " << seqlen << ", " << hdim
<< "]\n";
GpuBuffer<FmhaDataType> q_dev(q_elems);
GpuBuffer<FmhaDataType> k_dev(k_elems);
GpuBuffer<FmhaDataType> v_dev(v_elems);
GpuBuffer<FmhaDataType> o_dev(o_elems);
// Fill Q, K, V with random data
std::mt19937 rng(42);
std::uniform_real_distribution<float> dist(-0.5f, 0.5f);
std::vector<FmhaDataType> q_host(q_elems);
std::vector<FmhaDataType> k_host(k_elems);
std::vector<FmhaDataType> v_host(v_elems);
for(auto& x : q_host)
x = FmhaDataType(dist(rng));
for(auto& x : k_host)
x = FmhaDataType(dist(rng));
for(auto& x : v_host)
x = FmhaDataType(dist(rng));
q_dev.copy_from_host(q_host.data());
k_dev.copy_from_host(k_host.data());
v_dev.copy_from_host(v_host.data());
o_dev.zero();
// Step 4: Set up args with device pointers and strides
fmha_fwd_args fmha_args{};
fmha_args.q_ptr = q_dev.get();
fmha_args.k_ptr = k_dev.get();
fmha_args.v_ptr = v_dev.get();
fmha_args.o_ptr = o_dev.get();
fmha_args.bias_ptr = nullptr;
fmha_args.q_descale_ptr = nullptr;
fmha_args.k_descale_ptr = nullptr;
fmha_args.v_descale_ptr = nullptr;
fmha_args.rand_val_ptr = nullptr;
fmha_args.lse_ptr = nullptr;
fmha_args.sink_ptr = nullptr;
fmha_args.block_scale_seqstart_q_ptr = nullptr;
fmha_args.block_scale_seqstart_k_ptr = nullptr;
fmha_args.seqlen_q = seqlen;
fmha_args.seqlen_k = seqlen;
fmha_args.batch = batch;
fmha_args.max_seqlen_q = seqlen;
fmha_args.hdim_q = hdim;
fmha_args.hdim_v = hdim;
fmha_args.nhead_q = nhead;
fmha_args.nhead_k = nhead;
fmha_args.scale_s = scale;
fmha_args.logits_soft_cap = 0.0f;
// bhsd layout strides
fmha_args.stride_q = hdim;
fmha_args.stride_k = hdim;
fmha_args.stride_v = hdim;
fmha_args.stride_bias = 0;
fmha_args.stride_randval = 0;
fmha_args.stride_o = hdim;
fmha_args.nhead_stride_q = seqlen * hdim;
fmha_args.nhead_stride_k = seqlen * hdim;
fmha_args.nhead_stride_v = seqlen * hdim;
fmha_args.nhead_stride_bias = 0;
fmha_args.nhead_stride_randval = 0;
fmha_args.nhead_stride_lse = 0;
fmha_args.nhead_stride_o = seqlen * hdim;
fmha_args.nhead_stride_q_descale = 0;
fmha_args.nhead_stride_k_descale = 0;
fmha_args.nhead_stride_v_descale = 0;
fmha_args.batch_stride_q = nhead * seqlen * hdim;
fmha_args.batch_stride_k = nhead * seqlen * hdim;
fmha_args.batch_stride_v = nhead * seqlen * hdim;
fmha_args.batch_stride_bias = 0;
fmha_args.batch_stride_randval = 0;
fmha_args.batch_stride_lse = 0;
fmha_args.batch_stride_o = nhead * seqlen * hdim;
fmha_args.batch_stride_q_descale = 0;
fmha_args.batch_stride_k_descale = 0;
fmha_args.batch_stride_v_descale = 0;
fmha_args.window_size_left = -1;
fmha_args.window_size_right = -1;
fmha_args.sink_size = 0;
fmha_args.mask_type = 0;
fmha_args.min_seqlen_q = 0;
fmha_args.p_drop = 0.0f;
fmha_args.s_randval = false;
fmha_args.drop_seed_offset = std::make_pair(uint64_t(0), uint64_t(0));
fmha_args.block_scale_size_q = 0;
fmha_args.block_scale_size_kv = 0;
// Step 5: Run on GPU
std::cout << "\nStep 3: Run FMHA Forward on GPU\n";
float time_ms = 0.0f;
try
{
time_ms = dispatcher.run_fwd(traits, fmha_args, nullptr);
}
catch(const std::exception& e)
{
std::cerr << " ERROR: " << e.what() << "\n";
return 1;
}
auto problem = FmhaProblem::from_invocation(FmhaInvocation::make(traits, fmha_args), gfx_arch);
double tflops = static_cast<double>(problem.num_ops()) / (time_ms * 1e-3) / 1e12;
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Step 6: Copy output and validate
std::cout << "\nStep 4: Validate\n";
std::vector<FmhaDataType> o_host(o_elems);
o_dev.copy_to_host(o_host.data());
// Quick sanity check: output should be non-zero
int nonzero = 0;
for(int64_t i = 0; i < o_elems; ++i)
{
if(static_cast<float>(o_host[i]) != 0.0f)
++nonzero;
}
std::cout << " Non-zero outputs: " << nonzero << " / " << o_elems << "\n";
bool passed = (nonzero > 0);
if(args.has("--validate"))
{
// CPU reference
std::vector<float> q_f32(q_elems), k_f32(k_elems), v_f32(v_elems), o_ref(o_elems, 0.0f);
for(int64_t i = 0; i < q_elems; ++i)
q_f32[i] = static_cast<float>(q_host[i]);
for(int64_t i = 0; i < k_elems; ++i)
k_f32[i] = static_cast<float>(k_host[i]);
for(int64_t i = 0; i < v_elems; ++i)
v_f32[i] = static_cast<float>(v_host[i]);
cpu_attention_fwd(
q_f32, k_f32, v_f32, o_ref, batch, nhead, seqlen, seqlen, hdim, hdim, scale);
double max_abs_err = 0.0;
double max_rel_err = 0.0;
int errors = 0;
const double rtol = 1e-2;
const double atol = 1e-2;
for(int64_t i = 0; i < o_elems; ++i)
{
float gpu_val = static_cast<float>(o_host[i]);
float ref_val = o_ref[i];
double abs_err = std::abs(gpu_val - ref_val);
double rel_err = abs_err / (std::abs(ref_val) + 1e-6);
max_abs_err = std::max(max_abs_err, abs_err);
max_rel_err = std::max(max_rel_err, rel_err);
if(abs_err > atol + rtol * std::abs(ref_val))
++errors;
}
std::cout << " Max abs error: " << std::scientific << max_abs_err << "\n";
std::cout << " Max rel error: " << max_rel_err << "\n";
std::cout << " Errors: " << errors << " / " << o_elems << "\n";
passed = (errors == 0);
}
print_separator();
std::cout << "Status: " << (passed ? "PASS" : "FAIL") << "\n";
print_separator();
return passed ? 0 : 1;
}