Files
composable_kernel/example/ck_tile/01_fmha/example_fmha_bwd.cpp
Aviral Goel d85f065b15 chore(copyright): update copyright header for example directory (#3273)
* chore(copyright): update copyright header for codegen directory

* chore(copyright): update copyright header for example directory
2025-11-24 18:02:41 -08:00

200 lines
8.9 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "fmha_bwd.hpp"
#include "fmha_bwd_runner.hpp"
#include <string>
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "whether do CPU validation or not")
.insert("mode", "0", "kernel mode. 0:batch, 1:group")
.insert("b", "2", "batch size")
.insert("h", "8", "num of head, for q")
.insert("h_k",
"-1",
"num of head, for k/v, -1 means equal to h\n"
"if not equal to h, then this is GQA/MQA case")
.insert("s",
"3328",
"seqlen_q. if group-mode, means the average value of seqlen_q\n"
"total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary\n"
"also with \"-s=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_qpad",
"-1",
"padded seqlen_q per batch (group mode only). "
"Use \"-s_qpad=p0,p1,...\"; -1 disables explicit padding")
.insert("s_k",
"-1",
"seqlen_k, -1 means equal to s\n"
"also with \"-s_k=s0,s1,s2...\" comma-separated ints to set seqlen per batch "
"(group mode)")
.insert("s_kpad",
"-1",
"padded seqlen_k per batch (group mode only). "
"Use \"-s_kpad=k0,k1,...\"; -1 disables explicit padding")
.insert("d", "128", "head dim for q, k")
.insert("d_v", "-1", "head dim for v, -1 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
.insert("iperm",
"1",
"permute input\n"
"if true, will be b*h*s*d, else b*s*h*d")
.insert("operm", "1", "permute output")
.insert("bias",
"n",
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("dbias", "0", "output bias gradient or not")
.insert("prec", "fp16", "data type. fp32/fp16/bf16")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
"'t', top-left causal mask, 'b', bottom-r causal mask\n"
"'t:l,r', top-left sliding window attn(swa) with FA style left right size\n"
"'b:l,r', bottom-r sliding window attn(swa) with FA style left right size\n"
"'xt:window_size', xformer style masking from top-left, window_size negative is "
"causal, positive is swa\n"
"'xb:window_size', xformer style masking from bottom-r, window_size negative is "
"causal, positive is swa\n"
"'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for "
"now)")
.insert("kname", "0", "if set to 1 will print kernel name")
.insert("init",
"uf",
"init method:\n ui or 0 - uniform random int\n uf or 1 - uniform random float"
"\n tf or 2 - trig float")
.insert("seed",
"11939",
"random seed used for initializing input tensors. 0 for "
"non-deterministic seed")
.insert("p_drop", "0", "0~1 probability of dropout")
.insert("drop_seed", "1", "seed for dropout random number generator")
.insert("drop_offset", "0", "offset for dropout random number generator")
.insert(
"drop_prefs",
"0",
"whether dropout seed and offset values are present on GPU; 0 - host, 1 - device/GPU")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("warmup", "5", "number of iterations before benchmark the kernel")
.insert("repeat", "20", "number of iterations to benchmark the kernel")
.insert("deterministic",
"0",
"if set to 1 will use multi-buffer reduction strategy for dq, atomic operation "
"will not be used")
.insert("json", "0", "0: No Json, 1: Dump Results in Json format")
.insert("jsonfile", "fmha_bwd.json", "json file name to dump results");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataTypeConfig>
auto run(const ck_tile::ArgParser& arg_parser)
{
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
mode_enum mode = static_cast<mode_enum>(arg_parser.get_uint32("mode"));
ck_tile::index_t batch = arg_parser.get_int("b");
ck_tile::index_t nhead = arg_parser.get_int("h");
ck_tile::index_t nhead_k = arg_parser.get_int("h_k");
auto seqlen_qs = arg_parser.get_int_vec("s");
auto seqlen_qpads = arg_parser.get_int_vec("s_qpad");
auto seqlen_ks = arg_parser.get_int_vec("s_k");
auto seqlen_kpads = arg_parser.get_int_vec("s_kpad");
ck_tile::index_t hdim_q = arg_parser.get_int("d");
ck_tile::index_t hdim_v = arg_parser.get_int("d_v");
bool i_perm = arg_parser.get_bool("iperm");
bool o_perm = arg_parser.get_bool("operm");
float scale = arg_parser.get_float("scale");
std::string bias_str = arg_parser.get_str("bias");
bool use_dbias = arg_parser.get_bool("dbias");
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
bool drop_prefs = arg_parser.get_bool("drop_prefs");
std::string mask_str = arg_parser.get_str("mask");
bool deterministic = arg_parser.get_bool("deterministic");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
ck_tile::stream_config stream_config{nullptr,
true,
/* log_level = */ (arg_parser.get_bool("kname") ? 1 : 0),
arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_str("timer") == std::string("gpu")};
auto json = arg_parser.get_int("json") == 1
? std::optional<std::string>{arg_parser.get_str("jsonfile")}
: std::nullopt;
return fmha_bwd_run<DataTypeConfig>(mode,
batch,
nhead,
nhead_k,
seqlen_qs,
seqlen_ks,
seqlen_qpads,
seqlen_kpads,
hdim_q,
hdim_v,
i_perm,
o_perm,
scale,
bias_str,
use_dbias,
p_drop,
drop_seed,
drop_offset,
drop_prefs,
mask_str,
deterministic,
init_method,
seed,
do_validation,
stream_config,
json);
}
int main(int argc, char* argv[])
{
try
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp32")
{
return run<FmhaBwdFp32>(arg_parser) == bwd_result::success ? 0 : -2;
}
else if(data_type == "fp16")
{
return run<FmhaBwdFp16>(arg_parser) == bwd_result::success ? 0 : -2;
}
else if(data_type == "bf16")
{
return run<FmhaBwdBf16>(arg_parser) == bwd_result::success ? 0 : -2;
}
std::cerr << "Unsupported precision: " << data_type << std::endl;
return -1;
}
catch(const std::invalid_argument& e)
{
std::cerr << "Invalid argument: " << e.what() << std::endl;
return -1;
}
catch(const std::exception& e)
{
std::cerr << "Error: " << e.what() << std::endl;
return -2;
}
}