Cover more kernel tile-size instances and extend the test code to support listing, filtering, and benchmarking all compatible kernels

This commit is contained in:
Mohsen Saffari
2026-03-02 14:50:19 +00:00
parent d5acfd8d52
commit 27a99edef9
5 changed files with 427 additions and 121 deletions

View File

@@ -146,6 +146,9 @@ FMHA_FWD_API_HEADER = """
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include <cstdio>
#include <algorithm>
#include <string>
#include <vector>
#include <hip/hip_runtime.h>
@@ -182,7 +185,16 @@ unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seq
FMHA_FWD_API_FUNC_TEMPLATE = """
namespace {{
float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{
struct run_all_result_t {{
float time_ms;
double tflops;
double gbps;
const char* runtime_name;
}};
float r = -1;
[[maybe_unused]] bool output_started = false;
[[maybe_unused]] std::vector<run_all_result_t> run_all_results;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
@@ -198,6 +210,19 @@ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fw
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
if(t.run_all_kernels && !run_all_results.empty()) {{
std::sort(run_all_results.begin(), run_all_results.end(),
[](const auto& lhs, const auto& rhs) {{ return lhs.time_ms < rhs.time_ms; }});
printf("\\n");
for(const auto& result : run_all_results) {{
printf("%s, %.6f ms, %.2f TFlops, %.2f GB/s\\n",
result.runtime_name,
static_cast<double>(result.time_ms),
result.tflops,
result.gbps);
}}
r = run_all_results.front().time_ms;
}}
return r;
}}
}} // namespace
@@ -238,10 +263,51 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd
}}
"""
FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
FMHA_FWD_API_INNER_DISPATCH = """if(t.list_kernels && (t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
constexpr const char* kernel_name = "{F_kernel_name}";
constexpr const char* runtime_name = "{F_runtime_name}";
const bool kernel_match = t.kernel_filter.empty() ||
(std::string(kernel_name).find(t.kernel_filter) != std::string::npos) ||
(std::string(runtime_name).find(t.kernel_filter) != std::string::npos);
if(kernel_match) {{
if(!output_started) {{
printf("\\n");
output_started = true;
}}
printf("%s | %s\\n", kernel_name, runtime_name);
}}
}}
if(t.run_all_kernels && (t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
constexpr const char* kernel_name = "{F_kernel_name}";
constexpr const char* runtime_name = "{F_runtime_name}";
const bool kernel_match = t.kernel_filter.empty() ||
(std::string(kernel_name).find(t.kernel_filter) != std::string::npos) ||
(std::string(runtime_name).find(t.kernel_filter) != std::string::npos);
if(kernel_match) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
auto s_run_all = s;
s_run_all.log_level_ = 0;
const float cur = fmha_fwd_<trait_, {F_arch.tag}>(s_run_all, a);
if(cur >= 0) {{
const double tflops = static_cast<double>(t.perf_flop) / 1.0e9 / static_cast<double>(cur);
const double gbps = static_cast<double>(t.perf_num_byte) / 1.0e6 / static_cast<double>(cur);
run_all_results.emplace_back(run_all_result_t{{cur, tflops, gbps, runtime_name}});
}}
}}
}}
{F_if}((!t.list_kernels) && (!t.run_all_kernels) && (t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) &&
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
constexpr const char* kernel_name = "{F_kernel_name}";
constexpr const char* runtime_name = "{F_runtime_name}";
const bool kernel_match = t.kernel_filter.empty() ||
(std::string(kernel_name).find(t.kernel_filter) != std::string::npos) ||
(std::string(runtime_name).find(t.kernel_filter) != std::string::npos);
if(kernel_match) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>;
return fmha_fwd_<trait_, {F_arch.tag}>(s, a);
}}
}}
"""
@@ -288,12 +354,13 @@ class FmhaFwdApiTrait:
skip: str
tr_load: str
sink: str
runtime_name: str
constraint: CppConstraint
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}"
)
@@ -577,6 +644,8 @@ class FmhaFwdApiPool:
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_kernel_name=trait.name,
F_runtime_name=trait.runtime_name,
)
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
@@ -743,6 +812,7 @@ class FmhaFwdKernel:
skip=self.F_pipeline.F_skip,
tr_load=self.F_pipeline.F_trload,
sink=self.F_pipeline.F_sink,
runtime_name=self.name,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint,
)
@@ -837,24 +907,9 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
problem_ctx: ProblemContext, kernel_ctx: KernelContext
) -> bool:
if problem_ctx.dtype != "fp32":
# TODO: update if >=gfx11 archs get qr_async and qr_async_trload support
if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and (
(
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
and kernel_ctx.tile.F_bn0 != 128
)
or (
(problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128)
and kernel_ctx.tile.F_bm0 != 128
)
or (
(problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128)
and kernel_ctx.pipeline.tag != "qr_async"
and kernel_ctx.tile.F_bk0 == 64
)
if kernel_ctx.pipeline.tag in {"qr", "qs"} and (
kernel_ctx.tile.F_bk0 >= problem_ctx.hdim
):
# non qr_async_trload only support km0=128 tile size when hdim is not 128
# non qr_async only support kn0=128 tile size when hdim is 128
return False
return True
@@ -946,9 +1001,20 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9):
( 80, 96) : [FmhaFwdTileSize(128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 16, 64, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 16, 128, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize( 32, 64, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize( 32, 64, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 32, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 32, 256, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')),
FmhaFwdTileSize( 64, 256, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 32, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 256, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
# (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],

View File

@@ -253,6 +253,9 @@ std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}, {F_arch.tag}>()
FMHA_FWD_SPLITKV_API_FILENAME = "fmha_fwd_splitkv_api.cpp"
FMHA_FWD_SPLITKV_API = """
#include <iostream>
#include <algorithm>
#include <utility>
#include <vector>
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_, typename Arch>
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
@@ -270,17 +273,76 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a
}}
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s) {{
struct run_all_result_t {{
float time_ms;
double tflops;
double gbps;
std::string kernel_name;
}};
float r = -1;
[[maybe_unused]] bool output_started = false;
[[maybe_unused]] std::vector<run_all_result_t> run_all_results;
[[maybe_unused]] const std::string device_name = ck_tile::get_device_name();
{F_dispatch}
if(t.list_kernels) {{
{F_dispatch_list}
return r;
}}
if(t.run_all_kernels) {{
{F_dispatch_run_all}
if(!run_all_results.empty()) {{
std::sort(run_all_results.begin(), run_all_results.end(),
[](const auto& lhs, const auto& rhs) {{ return lhs.time_ms < rhs.time_ms; }});
std::cout << "\\n";
for(const auto& result : run_all_results) {{
printf("%s, %.6f ms, %.2f TFlops, %.2f GB/s\\n",
result.kernel_name.c_str(),
static_cast<double>(result.time_ms),
result.tflops,
result.gbps);
}}
r = run_all_results.front().time_ms;
}}
return r;
}}
{F_dispatch_exec}
return r;
}}
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
FMHA_FWD_SPLITKV_API_INNER_DISPATCH_LIST = """if((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) &&
(t.kernel_filter.empty() ||
(std::string("{F_kernel_name}").find(t.kernel_filter) != std::string::npos) ||
(std::string("{F_runtime_name}").find(t.kernel_filter) != std::string::npos))) {{
if(!output_started) {{
std::cout << "\\n";
output_started = true;
}}
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
// unsupported(fp8+lse)
}} else {{
std::cout << fmha_fwd_splitkv_get_name_<traits_, {F_arch.tag}>()
<< "\\n";
}}
}} else {{
std::cout << fmha_fwd_splitkv_get_name_<traits_, {F_arch.tag}>()
<< "\\n";
}}
}}
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH_EXEC = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) &&
(t.kernel_filter.empty() ||
(std::string("{F_kernel_name}").find(t.kernel_filter) != std::string::npos) ||
(std::string("{F_runtime_name}").find(t.kernel_filter) != std::string::npos))) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
// get combine kernel tile sizes
@@ -307,6 +369,52 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) &&
}}
"""
FMHA_FWD_SPLITKV_API_INNER_DISPATCH_RUN_ALL = """if((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) &&
(t.kernel_filter.empty() ||
(std::string("{F_kernel_name}").find(t.kernel_filter) != std::string::npos) ||
(std::string("{F_runtime_name}").find(t.kernel_filter) != std::string::npos))) {{
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
if (t.has_lse) {{
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
// unsupported(fp8+lse)
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, true, {F_squant}, {F_spad}, {F_dvpad}>;
auto s_run_all = s;
s_run_all.log_level_ = 0;
const float cur = fmha_fwd_splitkv_<traits_, traits2_, {F_arch.tag}>(s_run_all, a);
if(cur >= 0) {{
const double tflops = static_cast<double>(t.perf_flop) / 1.0e9 / static_cast<double>(cur);
const double gbps = static_cast<double>(t.perf_num_byte) / 1.0e6 / static_cast<double>(cur);
run_all_results.emplace_back(
run_all_result_t{{
cur,
tflops,
gbps,
fmha_fwd_splitkv_get_name_<traits_, {F_arch.tag}>(),
}});
}}
}}
}} else {{
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, false, {F_squant}, {F_spad}, {F_dvpad}>;
auto s_run_all = s;
s_run_all.log_level_ = 0;
const float cur = fmha_fwd_splitkv_<traits_, traits2_, {F_arch.tag}>(s_run_all, a);
if(cur >= 0) {{
const double tflops = static_cast<double>(t.perf_flop) / 1.0e9 / static_cast<double>(cur);
const double gbps = static_cast<double>(t.perf_num_byte) / 1.0e6 / static_cast<double>(cur);
run_all_results.emplace_back(
run_all_result_t{{
cur,
tflops,
gbps,
fmha_fwd_splitkv_get_name_<traits_, {F_arch.tag}>(),
}});
}}
}}
}}
"""
@dataclass
class FmhaFwdSplitKVApiTrait:
@@ -335,11 +443,12 @@ class FmhaFwdSplitKVApiTrait:
pagedkv: str
sink: str # sink or not
bn1comb: int # tile size along v head_dim of combine kernel
runtime_name: str
@property
def name(self) -> str:
return (
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-"
f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-"
+ f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-"
+ f"{self.dvpad}-{self.pagedkv}-{self.sink}"
)
@@ -552,16 +661,23 @@ class FmhaFwdSplitKVApiPool:
@property
def api(self) -> str:
per_arch = str()
per_arch_list_all = str()
per_arch_run_all_all = str()
per_arch_exec_all = str()
for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()):
per_dtypes = str()
per_dtypes_list = str()
per_dtypes_run_all = str()
per_dtypes_exec = str()
for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()):
per_hdim_case = str()
per_hdim_case_list = str()
per_hdim_case_run_all = str()
per_hdim_case_exec = str()
for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()):
inners = str()
inners_list = str()
inners_run_all = str()
inners_exec = str()
for i_trait, trait in enumerate(pool_by_hdim):
inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(
F_if=if_(i_trait),
common = dict(
F_arch=arch,
F_mode=MODE_MAP[trait.mode],
F_vlayout=LAYOUT_MAP[trait.vlayout],
@@ -591,27 +707,76 @@ class FmhaFwdSplitKVApiPool:
F_bk0max=trait.bk0max,
F_hdim=hdim,
F_dtype=FWD_DTYPE_MAP[dtype],
F_kernel_name=trait.name,
F_runtime_name=trait.runtime_name,
F_bn1comb=trait.bn1comb,
)
per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format(
inners_list += FMHA_FWD_SPLITKV_API_INNER_DISPATCH_LIST.format(
**common,
)
inners_run_all += FMHA_FWD_SPLITKV_API_INNER_DISPATCH_RUN_ALL.format(
**common,
)
inners_exec += FMHA_FWD_SPLITKV_API_INNER_DISPATCH_EXEC.format(
F_if=if_(i_trait),
**common,
)
per_hdim_case_list += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_hdim_v=hdim,
F_inner_dispatch=indent(inners),
F_inner_dispatch=indent(inners_list),
)
per_dtypes += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case)
per_hdim_case_exec += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_hdim_v=hdim,
F_inner_dispatch=indent(inners_exec),
)
per_hdim_case_run_all += FMHA_FWD_API_PER_HDIM_CASE.format(
F_if=if_(i_hdim),
F_hdim=hdim,
F_hdim_v=hdim,
F_inner_dispatch=indent(inners_run_all),
)
per_dtypes_list += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case_list)
)
per_arch += FMHA_FWD_API_PER_ARCH.format(
per_dtypes_run_all += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype),
F_dtype=dtype,
F_hdim_case=indent(per_hdim_case_run_all),
)
per_dtypes_exec += FMHA_FWD_API_PER_DTYPE.format(
F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case_exec)
)
per_arch_list = FMHA_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes),
F_dtype_case=indent(per_dtypes_list),
)
if not per_arch:
per_arch_exec = FMHA_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes_exec),
)
per_arch_run_all = FMHA_FWD_API_PER_ARCH.format(
F_if=if_(i_arch),
F_arch=arch,
F_dtype_case=indent(per_dtypes_run_all),
)
per_arch_list_all += per_arch_list
per_arch_run_all_all += per_arch_run_all
per_arch_exec_all += per_arch_exec
if not per_arch_list_all:
# empty string we add some ignore to suppress warning in api
per_arch = "(void)t; (void)s; (void)a;"
per_arch_list_all = "(void)t; (void)s; (void)a;"
per_arch_run_all_all = "(void)t; (void)s; (void)a;"
per_arch_exec_all = "(void)t; (void)s; (void)a;"
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(
F_dispatch=indent(per_arch)
F_dispatch_list=indent(per_arch_list_all),
F_dispatch_run_all=indent(per_arch_run_all_all),
F_dispatch_exec=indent(per_arch_exec_all),
)
@@ -823,7 +988,24 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase):
"32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"128": [
FmhaFwdTileSize( 16, 32, 32, 128, 16, 128, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 16, 64, 32, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 16, 128, 32, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 16, 256, 32, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 32, 32, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 32, 64, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 32, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 32, 256, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 256, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 32, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 256, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
],
# "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
"256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
} # fmt: skip
@@ -886,77 +1068,80 @@ def get_fwd_splitkv_blobs(
continue
# for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
if mode == "group":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
# logits_soft_cap is only allowed if no bias
if not (
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
or pipeline.F_logits == "f"
):
continue
k = Kernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16, bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
tiles = d[hdim_str]
if not isinstance(tiles, list):
tiles = [tiles]
for tile in tiles:
for pipeline in factory.get_pipelines(dtype, hdim, mask_impl):
if mode == "group":
if pipeline.F_spad != "t" or pipeline.F_skpad != "t":
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
# logits_soft_cap is only allowed if no bias
if not (
(pipeline.F_logits == "t" and pipeline.F_bias == "no")
or pipeline.F_logits == "f"
):
continue
k = Kernel(
F_arch=factory.arch,
F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl,
)
if kernel_filter != "":
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# Flash attention integration
if receipt == 2:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "alibi"]
cond &= pipeline.F_squant == "f"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ["fp16, bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_bias in ["no", "bias"]
cond &= pipeline.F_squant == "f"
cond &= mode == "batch"
cond &= pipeline.F_sink == "f"
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ["fp16", "bf16"]
cond &= mode == "group"
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# aiter::mha_fwd_splikv C++ api integration
elif receipt == 600:
cond = dtype in ["fp16", "bf16"]
cond &= pipeline.F_vlayout == "row"
cond &= pipeline.F_squant == "f"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == "fp32"
if not cond:
continue
gen.append(k)
gen.append(k)
return gen
@@ -1096,6 +1281,7 @@ def write_blobs(
dpad=kernel.F_pipeline.F_dpad,
dvpad=kernel.F_pipeline.F_dvpad,
bn1comb=combine_kernel.F_tile.F_bn1,
runtime_name=kernel.name,
)
)
write_fwd_splitkv_api(api_pool, output_dir)

View File

@@ -77,6 +77,15 @@ auto create_args(int argc, char* argv[])
.insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)")
.insert("lse", "0", "0 not store lse, 1 store lse")
.insert("kname", "0", "if set to 1 will print kernel name")
.insert("kernel_filter",
"",
"optional substring filter for selecting a specific prebuilt fmha_fwd kernel")
.insert("list_kernels",
"0",
"if set to 1, list compatible fmha_fwd kernels and do not launch")
.insert("run_all_kernels",
"0",
"if set to 1, run all compatible fmha_fwd kernels and report the best time")
.insert("init",
"uf",
"init method:\n ui or 0 - uniform random int\n ni - normalized random int"
@@ -146,19 +155,22 @@ auto run(const ck_tile::ArgParser& arg_parser)
bool is_v_rowmajor = arg_parser.get_str("vlayout") == "r";
bool lse = arg_parser.get_bool("lse");
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
std::string bias_str = arg_parser.get_str("bias");
std::string qscale_str = arg_parser.get_str("qscale");
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
bool drop_prefs = arg_parser.get_bool("drop_prefs");
std::string mask_str = arg_parser.get_str("mask");
bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
int init_sink_value = arg_parser.get_int("init_sink");
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
std::string bias_str = arg_parser.get_str("bias");
std::string qscale_str = arg_parser.get_str("qscale");
float p_drop = arg_parser.get_float("p_drop");
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
bool drop_prefs = arg_parser.get_bool("drop_prefs");
std::string mask_str = arg_parser.get_str("mask");
bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved");
ck_tile::index_t num_splits = arg_parser.get_int("num_splits");
std::string kernel_filter = arg_parser.get_str("kernel_filter");
bool list_kernels = arg_parser.get_bool("list_kernels");
bool run_all_kernels = arg_parser.get_bool("run_all_kernels");
std::string init_method = arg_parser.get_str("init");
uint32_t seed = arg_parser.get_uint32("seed");
int init_sink_value = arg_parser.get_int("init_sink");
ck_tile::stream_config stream_config{nullptr,
true,
@@ -202,6 +214,9 @@ auto run(const ck_tile::ArgParser& arg_parser)
qscale_str,
is_rotary_interleaved,
num_splits,
kernel_filter,
list_kernels,
run_all_kernels,
init_method,
seed,
do_validation,

View File

@@ -1598,6 +1598,11 @@ struct fmha_fwd_traits
quant_scale_enum qscale_type;
bool skip_min_seqlen_q = false;
bool has_sink = false;
std::string kernel_filter = "";
bool list_kernels = false;
bool run_all_kernels = false;
std::size_t perf_flop = 0;
std::size_t perf_num_byte = 0;
// TODO: padding check is inside this api
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
@@ -1637,6 +1642,11 @@ struct fmha_fwd_splitkv_traits
bool has_lse;
bool do_fp8_static_quant;
bool has_sink = false;
std::string kernel_filter = "";
bool list_kernels = false;
bool run_all_kernels = false;
std::size_t perf_flop = 0;
std::size_t perf_num_byte = 0;
// TODO: padding check is inside this api
};
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits,

View File

@@ -203,6 +203,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::string qscale_str,
bool is_rotary_interleaved,
ck_tile::index_t num_splits,
const std::string& kernel_filter,
bool list_kernels,
bool run_all_kernels,
std::string init_method,
uint32_t seed,
int do_validation,
@@ -974,11 +977,27 @@ fwd_result fmha_fwd_run(mode_enum mode,
{
traits.has_dropout = (p_drop > 0.0f);
traits.qscale_type = qscale.type;
traits.kernel_filter = kernel_filter;
traits.list_kernels = list_kernels;
traits.run_all_kernels = run_all_kernels;
traits.perf_flop = flop;
traits.perf_num_byte = num_byte;
}
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_traits,
std::decay_t<decltype(traits)>>)
{
traits.use_pagedkv = (0 < page_block_size);
traits.do_fp8_static_quant = (data_type == "fp8");
}
else if constexpr(std::is_same_v<fmha_fwd_splitkv_traits,
std::decay_t<decltype(traits)>>)
{
traits.kernel_filter = kernel_filter;
traits.list_kernels = list_kernels;
traits.run_all_kernels = run_all_kernels;
traits.perf_flop = flop;
traits.perf_num_byte = num_byte;
traits.do_fp8_static_quant = (data_type == "fp8");
}
}
};
@@ -1363,9 +1382,19 @@ fwd_result fmha_fwd_run(mode_enum mode,
fmha_fwd_args fmha_args;
init_args(fmha_args);
if(fmha_traits.list_kernels || fmha_traits.run_all_kernels)
{
std::cout << std::endl;
}
return fmha_fwd(fmha_traits, fmha_args, sc);
};
const float fwd_ave_time = run_fwd(stream_config);
if(list_kernels)
{
std::cout << std::flush << std::endl;
return fwd_result::success;
}
if(fwd_ave_time < 0.0f)
{
std::cout << ", not supported yet" << std::flush << std::endl;