Files
composable_kernel/example/ck_tile/01_fmha/example_fmha_fwd.cpp
Anton Gorenko ec006bb8e0 [CK_TILE] Add gtests for FMHA (#2744)
* Improve random number generation

* use different seed for each input (Q, K, V...);
* use deterministic generation of:
  * seqstart_q/k (for group mode);
  * block_table (for paged-kvcache);
  * cache_batch_idx (for kvcache);

* Extract arg_parser-related code from run functions to use them as tests

* Split examples into main programs and fmha runners, build instances separately

* Add dummy tests that use instances and runners

* Fix a missed corner case of f32->f8 conversion

When the value is < min f8 denormal but > min f8 denormal / 2, it must be
rounded to min f8 denormal (i.e. 0b1), not to 0.

* Fix incorrect fp8 scales for P and O in validation code

DataTypeConfig was incorrectly compared with fp8_t.

* Add host generation of dropout random values and use it for validation

Previously host validation (reference_batched_dropout) used random
numbers generated by BlockDropout of the kernel, meaning that incorrect
generation on device (bad distribution, repeated numbers, too many zeros,
etc.) would not trigger any validation errors.

* Implement tests from smoke_test_bwd.sh

* Return result as enum to distinguish failure and missing instance

* Add tests for bwd features: bias, alibi, dropout

* Implement tests from smoke_test_fwd.sh

* Pass seqlen_q/k as vectors to fwd and bwd runners

* Add tests for fwd features: bias, alibi, dropout

* Add tests for pagedkv and splitkv

* Fix conditions when to use splitkv and pagedkv kernels

splitkv was executed only when use_kvcache was true, where use_kvcache == (need_append_kvcache || use_cache_batch_idx || 0 < page_block_size).
In the SplitKV tests: the regular fwd kernel was executed if use_cache_batch_idx was not requested even when num_splitkv > 1.
In the AppendKV tests: the pagedkv kernel was executed but it often failed to find an instance.

* Add tests for appendkv

* Use is_v_rowmajor = true because there are no instances with column layout anymore

* Split public and private compile options for instances

Tests and examples need to know only about CK_TILE_FMHA_FWD_*_API.

* Improve parsing validation in bias and mask

* Pass bias as string for consistency with mask

* Catch parsing and other exceptions

* Add bwd test for deterministic flag

* Initialize fp8 tensors (-init=ufq) similarly to uf

* Fix splitkv/pagedkv invocation: use padded sk when seqlen_k_ptr is not null

seqlen_k cannot be used to determine padding when seqlen_k_ptr is
provided. The actual seqlen_k is taken from seqlen_k_ptr[b].
Even seqlen_k values (% bn0 == 0) use padded seqlen_k while seqlen_k_ptr
may contain arbitrary values.
In the example or tests this produces incorrect results with appendkv
(for example, -d=32 -s=1 -s_k=64 -s_knew=7 -vlayout=c -b=8).

* Fix use_pagedkv value when kvcache = true but page_block_size = 0

In this case block_table_ptr is nullptr which is accessed in the kernel.

* Clean up bwd tests

* Unify fwd tests for f16/bf16 and fp8

* Use better explicit instantiation declaration for fmha_bwd<2>

* Use the same seed for all tests, allow overriding it with an env variable

* Undo clang-format of one irrelevant file

For some reason my local clang-format-18 and the one in CI work differently.

* Do not build instances and tests on unsupported archs

* Build instance libraries as OBJECT library

* CI: Enable sccache for HIP

There are source files with LANGUAGE HIP, they need
-DCMAKE_HIP_COMPILER_LAUNCHER=sccache

* Add tests to REGRESSION_TESTS

* Fix OOB accesses in deterministic bwd due to incorrectly assumed kN0

The runner assumes kN0 = (hdim_q <= 128) ? 128 : 64 but there are
smaller tiles (for tr_load or fp32). This can create too small dq_acc_buf.

* Pass CK_TILE_FMHA_FWD_*_API as INTERFACE compile options

The instances don't actually depend on them, only examples and tests do.
Passing these definitions as INTERFACE allows changing FMHA_FWD_ENABLE_APIS
without recompiling instances that are already in ccache.

* Fix formatting and names
2025-09-10 08:06:14 +05:00

254 lines
12 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "fmha_fwd.hpp"
#include "fmha_fwd_runner.hpp"
#include <string>
/// Registers every command-line option of the FMHA forward example (name, default value and
/// help text) on a ck_tile::ArgParser and parses the given command line with it.
///
/// @param argc argument count, as received by main()
/// @param argv argument vector, as received by main()
/// @return std::tuple of (value returned by ArgParser::parse, the populated parser)
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser
        // Each validation mode gets a distinct number; the default is "1".
        .insert("v", "1", "0:no validation, 1:cpu validation, 2:gpu validation(experimental)")
        .insert("mode", "0", "kernel mode. 0:batch, 1:group")
        .insert("b", "2", "batch size")
        .insert("h", "8", "num of head, for q")
        .insert("h_k",
                "-1",
                "num of head, for k/v, -1 means equal to h\n"
                "if not equal to h, then this is GQA/MQA case")
        .insert("s",
                "3328",
                "seqlen_q. if group-mode, means the average value of seqlen_q\n"
                "total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
                "also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_k",
                "-1",
                "seqlen_k (including new key/value), -1 means equal to s\n"
                "also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
                "(group mode)")
        .insert("s_knew",
                "0",
                "seqlen_k for new key/value, 0 means not to use this at all; "
                "-1 to choose s_knew in [1, s] randomly.")
        .insert("s_kpad",
                "-1",
                "seqlen_k stride between 2 batches, currently used in group-mode only\n"
                "for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride\n"
                "along seqlen, instead of packed, same as xformer kv_padding,\n"
                "must be greater than or equal to s_k")
        .insert("d", "128", "head dim for q, k")
        .insert("d_v", "-1", "head dim for v, -1 means equal to d")
        .insert("scale_s",
                "0",
                "scale factor of S. 0 means equal to 1/sqrt(hdim).\n"
                "note when squant=1, this value will be modified by range_q/k")
        .insert("logits_soft_cap", "0", "attention logits soft capping value.")
        .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.")
        .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.")
        .insert("range_v", "16", "per-tensor quantization range of v. used if squant=1.")
        .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.")
        .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.")
        .insert("squant",
                "auto",
                "if using static quantization fusion or not. auto: fp8 will default use squant, "
                "other will not\n"
                "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to "
                "P and O.\n"
                "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, "
                "range_p, range_o")
        .insert("iperm",
                "1",
                "permute input\n"
                "if true, will be b*h*s*d, else b*s*h*d")
        .insert("operm", "1", "permute output")
        .insert("bias",
                "n",
                "n or 0, no bias\n"
                "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
                "a(libi) or 2, alibi with 1*h. a:1, b*h")
        // Keep this list in sync with the precisions dispatched in main(); anything else is
        // rejected there as an unsupported precision.
        .insert("prec", "fp16", "data type. fp16/bf16/fp8")
        .insert("mask",
                "0",
                "0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
                "'t', top-left causal mask, 'b', bottom-r causal mask\n"
                "'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
                "'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
                "'xt:window_size', xformer style masking from top-left, window_size negative is "
                "causal, positive is swa\n"
                "'xb:window_size', xformer style masking from bottom-r, window_size negative is "
                "causal, positive is swa\n"
                "'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
                "now)")
        .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
        .insert("lse", "0", "0 not store lse, 1 store lse")
        .insert("kname", "0", "if set to 1 will print kernel name")
        .insert("init",
                "uf",
                "init method:\n ui or 0 - uniform random int\n ni - normalized random int"
                "\n uf or 1 - uniform random float\n nf - normalized random float"
                "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization")
        .insert("seed",
                "11939",
                "random seed used for initializing input tensors. 0 for "
                "non-deterministic seed")
        .insert("p_drop", "0", "0~1 probability of dropout")
        .insert("drop_seed", "1", "seed for dropout random number generator")
        .insert("drop_offset", "0", "offset for dropout random number generator")
        .insert(
            "drop_prefs",
            "0",
            "whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
        .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
        .insert(
            "rotary_dim", "0", "RoPE rotary dimension. rotary_dim <= 0 means not apply RoPE at all")
        .insert("rotary_interleaved", "1", "whether to apply interleaved RoPE")
        .insert("num_splits",
                "1",
                "# of splits for key/value. 0 to determine actual number by heuristic")
        .insert("page_block_size", "0", "paged-kvcache block size. 0 means not use paged-kvcache")
        .insert("cache_batch_idx", "0", "whether to use index map to the kvcache")
        .insert("warmup", "5", "number of iterations before benchmark the kernel")
        .insert("repeat", "20", "number of iterations to benchmark the kernel")
        .insert("json", "0", "0: No Json, 1: Dump Results in Json format")
        .insert("jsonfile", "fmha_fwd.json", "json file name to dump results");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Pulls every runtime option out of the parsed command line and hands them to
// fmha_fwd_run<DataTypeConfig>, returning whatever the runner returns.
// Options are read in a fixed order, one statement per option.
template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
    // --- validation mode and problem shape ---
    int validation_mode          = arg_parser.get_int("v");
    mode_enum kernel_mode        = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
    ck_tile::index_t n_batch     = arg_parser.get_int("b");
    ck_tile::index_t n_head_q    = arg_parser.get_int("h");
    ck_tile::index_t n_head_kv   = arg_parser.get_int("h_k");
    auto q_lengths               = arg_parser.get_int_vec("s");
    auto k_lengths               = arg_parser.get_int_vec("s_k");
    ck_tile::index_t head_dim_qk = arg_parser.get_int("d");
    ck_tile::index_t head_dim_v  = arg_parser.get_int("d_v");
    ck_tile::index_t new_k_len   = arg_parser.get_int("s_knew");
    auto k_pad_lengths           = arg_parser.get_int_vec("s_kpad");
    ck_tile::index_t rope_dim    = arg_parser.get_int("rotary_dim");

    // --- layout, scaling and kv-cache switches ---
    bool permute_in             = arg_parser.get_bool("iperm");
    bool permute_out            = arg_parser.get_bool("operm");
    float softmax_scale         = arg_parser.get_float("scale_s");
    float soft_cap              = arg_parser.get_float("logits_soft_cap");
    bool v_is_row_major         = arg_parser.get_str("vlayout") == "r";
    bool store_lse              = arg_parser.get_bool("lse");
    ck_tile::index_t page_size  = arg_parser.get_int("page_block_size");
    bool with_cache_batch_index = arg_parser.get_bool("cache_batch_idx");

    // --- bias, dropout and mask ---
    std::string bias_spec   = arg_parser.get_str("bias");
    float dropout_prob      = arg_parser.get_float("p_drop");
    uint64_t dropout_seed   = arg_parser.get_uint64("drop_seed");
    uint64_t dropout_offset = arg_parser.get_uint64("drop_offset");
    bool dropout_prefs      = arg_parser.get_bool("drop_prefs");
    std::string mask_spec   = arg_parser.get_str("mask");

    // --- quantization ranges, RoPE flavour, split count, initialization ---
    float q_range           = arg_parser.get_float("range_q");
    float k_range           = arg_parser.get_float("range_k");
    float v_range           = arg_parser.get_float("range_v");
    float p_range           = arg_parser.get_float("range_p");
    float o_range           = arg_parser.get_float("range_o");
    bool rope_interleaved   = arg_parser.get_bool("rotary_interleaved");
    ck_tile::index_t splits = arg_parser.get_int("num_splits");
    std::string init_kind   = arg_parser.get_str("init");
    uint32_t rng_seed       = arg_parser.get_uint32("seed");

    // "auto" resolves to "on" only for the fp8 configuration; any other value is read as a bool.
    bool static_quant = std::is_same_v<DataTypeConfig, FmhaFwdFp8>;
    if(!(arg_parser.get_str("squant") == "auto"))
    {
        static_quant = arg_parser.get_bool("squant");
    }

    // --- stream / timing configuration ---
    const auto log_level     = arg_parser.get_bool("kname") ? 1 : 0;
    const auto warmup_iters  = arg_parser.get_int("warmup");
    const auto repeat_iters  = arg_parser.get_int("repeat");
    const auto use_gpu_timer = arg_parser.get_str("timer") == std::string("gpu");
    ck_tile::stream_config stream_cfg{
        nullptr, true, log_level, warmup_iters, repeat_iters, use_gpu_timer};

    // The json file name is only looked up when json output was requested.
    std::optional<std::string> json_file;
    if(arg_parser.get_int("json") == 1)
    {
        json_file = std::optional<std::string>{arg_parser.get_str("jsonfile")};
    }

    return fmha_fwd_run<DataTypeConfig>(kernel_mode,
                                        n_batch,
                                        n_head_q,
                                        n_head_kv,
                                        q_lengths,
                                        k_lengths,
                                        head_dim_qk,
                                        head_dim_v,
                                        new_k_len,
                                        k_pad_lengths,
                                        rope_dim,
                                        permute_in,
                                        permute_out,
                                        softmax_scale,
                                        soft_cap,
                                        v_is_row_major,
                                        store_lse,
                                        page_size,
                                        with_cache_batch_index,
                                        bias_spec,
                                        dropout_prob,
                                        dropout_seed,
                                        dropout_offset,
                                        dropout_prefs,
                                        mask_spec,
                                        q_range,
                                        k_range,
                                        v_range,
                                        p_range,
                                        o_range,
                                        static_quant,
                                        rope_interleaved,
                                        splits,
                                        init_kind,
                                        rng_seed,
                                        validation_mode,
                                        stream_cfg,
                                        json_file);
}
// Program entry point: parses the command line, picks the runner matching "-prec" and maps
// the outcome to a process exit code.
//   0  - the selected runner returned fwd_result::success
//   -1 - argument parsing failed, the precision is unsupported, or std::invalid_argument was thrown
//   -2 - the runner returned anything other than success, or another std::exception was thrown
int main(int argc, char* argv[])
{
    // Translates a runner result into the exit code used for a completed run.
    const auto exit_code_of = [](const auto& outcome) {
        return outcome == fwd_result::success ? 0 : -2;
    };

    try
    {
        auto [parsed_ok, parser] = create_args(argc, argv);
        if(!parsed_ok)
        {
            return -1;
        }

        const std::string precision = parser.get_str("prec");

        if(precision == "fp16")
            return exit_code_of(run<FmhaFwdFp16>(parser));
        if(precision == "bf16")
            return exit_code_of(run<FmhaFwdBf16>(parser));
        if(precision == "fp8")
            return exit_code_of(run<FmhaFwdFp8>(parser));

        std::cerr << "Unsupported precision: " << precision << std::endl;
        return -1;
    }
    catch(const std::invalid_argument& bad_arg)
    {
        // Must stay ahead of the generic handler: invalid_argument derives from std::exception.
        std::cerr << "Invalid argument: " << bad_arg.what() << std::endl;
        return -1;
    }
    catch(const std::exception& failure)
    {
        std::cerr << "Error: " << failure.what() << std::endl;
        return -2;
    }
}