From 27a99edef9c4840035a0f7852b4fbe9b15c72357 Mon Sep 17 00:00:00 2001 From: Mohsen Saffari Date: Mon, 2 Mar 2026 14:50:19 +0000 Subject: [PATCH] cover more instances and change test code to do benchmarking --- .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 108 +++++- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 360 +++++++++++++----- example/ck_tile/01_fmha/example_fmha_fwd.cpp | 41 +- example/ck_tile/01_fmha/fmha_fwd.hpp | 10 + example/ck_tile/01_fmha/fmha_fwd_runner.hpp | 29 ++ 5 files changed, 427 insertions(+), 121 deletions(-) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index f9301878c4..570af5fb93 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -146,6 +146,9 @@ FMHA_FWD_API_HEADER = """ // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n // auto generated by generate.py #include +#include +#include +#include #include @@ -182,7 +185,16 @@ unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seq FMHA_FWD_API_FUNC_TEMPLATE = """ namespace {{ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fwd_args a, [[maybe_unused]] const ck_tile::stream_config& s) {{ + struct run_all_result_t {{ + float time_ms; + double tflops; + double gbps; + const char* runtime_name; + }}; + float r = -1; + [[maybe_unused]] bool output_started = false; + [[maybe_unused]] std::vector run_all_results; [[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate @@ -198,6 +210,19 @@ float {F_func_name}([[maybe_unused]] fmha_fwd_traits t, [[maybe_unused]] fmha_fw [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); {F_dispatch} + if(t.run_all_kernels && !run_all_results.empty()) {{ + std::sort(run_all_results.begin(), run_all_results.end(), + [](const auto& lhs, const auto& rhs) {{ return lhs.time_ms < rhs.time_ms; }}); + printf("\\n"); + for(const auto& result : run_all_results) {{ + printf("%s, %.6f ms, %.2f TFlops, %.2f GB/s\\n", + result.runtime_name, + static_cast(result.time_ms), + result.tflops, + result.gbps); + }} + r = run_all_results.front().time_ms; + }} return r; }} }} // namespace @@ -238,10 +263,51 @@ FMHA_FWD_API_PER_HDIM_CASE = """{F_if}(t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hd }} """ -FMHA_FWD_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && +FMHA_FWD_API_INNER_DISPATCH = """if(t.list_kernels && (t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ - using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; - return fmha_fwd_(s, a); + constexpr const char* kernel_name = "{F_kernel_name}"; + constexpr const char* runtime_name = "{F_runtime_name}"; + const bool kernel_match = t.kernel_filter.empty() || + (std::string(kernel_name).find(t.kernel_filter) != std::string::npos) || + (std::string(runtime_name).find(t.kernel_filter) != std::string::npos); + if(kernel_match) {{ + if(!output_started) {{ + printf("\\n"); + output_started = true; + }} + printf("%s | %s\\n", kernel_name, runtime_name); + }} +}} +if(t.run_all_kernels && (t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + constexpr const char* kernel_name = "{F_kernel_name}"; + constexpr const char* runtime_name = "{F_runtime_name}"; + const bool kernel_match = t.kernel_filter.empty() || + (std::string(kernel_name).find(t.kernel_filter) != std::string::npos) || + (std::string(runtime_name).find(t.kernel_filter) != std::string::npos); + if(kernel_match) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; + auto s_run_all = s; + s_run_all.log_level_ = 0; + const float cur = fmha_fwd_(s_run_all, a); + if(cur >= 0) {{ + const double tflops = static_cast(t.perf_flop) / 1.0e9 / static_cast(cur); + const double gbps = static_cast(t.perf_num_byte) / 1.0e6 / static_cast(cur); + run_all_results.emplace_back(run_all_result_t{{cur, tflops, gbps, runtime_name}}); + }} + }} +}} +{F_if}((!t.list_kernels) && (!t.run_all_kernels) && (t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.qscale_type == {F_qscale_check}) && (t.skip_min_seqlen_q == {F_skip}) &&(t.has_sink == {F_sink}) && + ({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{ + constexpr const char* kernel_name = "{F_kernel_name}"; + constexpr const char* runtime_name = "{F_runtime_name}"; + const bool kernel_match = t.kernel_filter.empty() || + (std::string(kernel_name).find(t.kernel_filter) != std::string::npos) || + (std::string(runtime_name).find(t.kernel_filter) != std::string::npos); + if(kernel_match) {{ + using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_qscale}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}, {F_sink}>; + return fmha_fwd_(s, a); + }} }} """ @@ -288,12 +354,13 @@ class FmhaFwdApiTrait: skip: str tr_load: str sink: str + runtime_name: str constraint: CppConstraint @property def name(self) -> str: return ( - f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.qscale}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}-{self.sink}" ) @@ -577,6 +644,8 @@ class FmhaFwdApiPool: F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype], + F_kernel_name=trait.name, + F_runtime_name=trait.runtime_name, ) per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format( F_if=if_(i_hdim), @@ -743,6 +812,7 @@ class FmhaFwdKernel: skip=self.F_pipeline.F_skip, tr_load=self.F_pipeline.F_trload, sink=self.F_pipeline.F_sink, + runtime_name=self.name, constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint, ) @@ -837,24 +907,9 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory): problem_ctx: ProblemContext, kernel_ctx: KernelContext ) -> bool: if problem_ctx.dtype != "fp32": - # TODO: update if >=gfx11 archs get qr_async and qr_async_trload support - if kernel_ctx.pipeline.tag in cls._AVAILABLE_PIPELINES and ( - ( - (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) - and kernel_ctx.tile.F_bn0 != 128 - ) - or ( - (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128) - and kernel_ctx.tile.F_bm0 != 128 - ) - or ( - (problem_ctx.hdim, problem_ctx.hdim_v) == (128, 128) - and kernel_ctx.pipeline.tag != "qr_async" - and kernel_ctx.tile.F_bk0 == 64 - ) + if kernel_ctx.pipeline.tag in {"qr", "qs"} and ( + kernel_ctx.tile.F_bk0 >= problem_ctx.hdim ): - # non qr_async_trload only support km0=128 tile size when hdim is not 128 - # non qr_async only support kn0=128 tile size when hdim is 128 return False return True @@ -946,9 +1001,20 @@ class KernelComponentFactoryGfx9(CompatibilityRuleFactoryGfx9): ( 80, 96) : [FmhaFwdTileSize(128, 128, 16, 96, 32, 80, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], ( 96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], (128, 128) : [FmhaFwdTileSize( 16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize( 16, 64, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), + FmhaFwdTileSize( 16, 128, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1), FmhaFwdTileSize( 32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize( 32, 64, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize( 32, 64, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 32, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 32, 256, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1, CppConstraint('get_num_blocks(64) <= num_cus')), + FmhaFwdTileSize( 64, 256, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 32, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 256, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], # (160, 160) : [FmhaFwdTileSize(128, 128 , 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (192, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 9105900fc7..bc5898d677 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -253,6 +253,9 @@ std::string fmha_fwd_splitkv_combine_get_name_() FMHA_FWD_SPLITKV_API_FILENAME = "fmha_fwd_splitkv_api.cpp" FMHA_FWD_SPLITKV_API = """ #include +#include +#include +#include template float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) @@ -270,17 +273,76 @@ float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a }} float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s) {{ + struct run_all_result_t {{ + float time_ms; + double tflops; + double gbps; + std::string kernel_name; + }}; + float r = -1; + [[maybe_unused]] bool output_started = false; + [[maybe_unused]] std::vector run_all_results; [[maybe_unused]] const std::string device_name = ck_tile::get_device_name(); -{F_dispatch} + if(t.list_kernels) {{ +{F_dispatch_list} + return r; + }} + + if(t.run_all_kernels) {{ +{F_dispatch_run_all} + if(!run_all_results.empty()) {{ + std::sort(run_all_results.begin(), run_all_results.end(), + [](const auto& lhs, const auto& rhs) {{ return lhs.time_ms < rhs.time_ms; }}); + std::cout << "\\n"; + for(const auto& result : run_all_results) {{ + printf("%s, %.6f ms, %.2f TFlops, %.2f GB/s\\n", + result.kernel_name.c_str(), + static_cast(result.time_ms), + result.tflops, + result.gbps); + }} + r = run_all_results.front().time_ms; + }} + return r; + }} + +{F_dispatch_exec} return r; }} """ -FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && - ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ +FMHA_FWD_SPLITKV_API_INNER_DISPATCH_LIST = """if((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && + (t.kernel_filter.empty() || + (std::string("{F_kernel_name}").find(t.kernel_filter) != std::string::npos) || + (std::string("{F_runtime_name}").find(t.kernel_filter) != std::string::npos))) {{ + if(!output_started) {{ + std::cout << "\\n"; + output_started = true; + }} + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + if (t.has_lse) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ + // unsupported(fp8+lse) + }} else {{ + std::cout << fmha_fwd_splitkv_get_name_() + << "\\n"; + }} + }} else {{ + std::cout << fmha_fwd_splitkv_get_name_() + << "\\n"; + }} +}} +""" + +FMHA_FWD_SPLITKV_API_INNER_DISPATCH_EXEC = """{F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && + (t.kernel_filter.empty() || + (std::string("{F_kernel_name}").find(t.kernel_filter) != std::string::npos) || + (std::string("{F_runtime_name}").find(t.kernel_filter) != std::string::npos))) {{ using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; // get combine kernel tile sizes @@ -307,6 +369,52 @@ FMHA_FWD_SPLITKV_API_INNER_DISPATCH = """{F_if}((t.is_group_mode == {F_mode}) && }} """ +FMHA_FWD_SPLITKV_API_INNER_DISPATCH_RUN_ALL = """if((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) && + ((a.block_table_ptr != nullptr) == {F_pagedkv}) && (t.has_sink == {F_sink}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && + (t.kernel_filter.empty() || + (std::string("{F_kernel_name}").find(t.kernel_filter) != std::string::npos) || + (std::string("{F_runtime_name}").find(t.kernel_filter) != std::string::npos))) {{ + using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv},{F_sink}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>; + if (t.has_lse) {{ + if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{ + // unsupported(fp8+lse) + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, true, {F_squant}, {F_spad}, {F_dvpad}>; + auto s_run_all = s; + s_run_all.log_level_ = 0; + const float cur = fmha_fwd_splitkv_(s_run_all, a); + if(cur >= 0) {{ + const double tflops = static_cast(t.perf_flop) / 1.0e9 / static_cast(cur); + const double gbps = static_cast(t.perf_num_byte) / 1.0e6 / static_cast(cur); + run_all_results.emplace_back( + run_all_result_t{{ + cur, + tflops, + gbps, + fmha_fwd_splitkv_get_name_(), + }}); + }} + }} + }} else {{ + using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1comb}, false, {F_squant}, {F_spad}, {F_dvpad}>; + auto s_run_all = s; + s_run_all.log_level_ = 0; + const float cur = fmha_fwd_splitkv_(s_run_all, a); + if(cur >= 0) {{ + const double tflops = static_cast(t.perf_flop) / 1.0e9 / static_cast(cur); + const double gbps = static_cast(t.perf_num_byte) / 1.0e6 / static_cast(cur); + run_all_results.emplace_back( + run_all_result_t{{ + cur, + tflops, + gbps, + fmha_fwd_splitkv_get_name_(), + }}); + }} + }} +}} +""" + @dataclass class FmhaFwdSplitKVApiTrait: @@ -335,11 +443,12 @@ class FmhaFwdSplitKVApiTrait: pagedkv: str sink: str # sink or not bn1comb: int # tile size along v head_dim of combine kernel + runtime_name: str @property def name(self) -> str: return ( - f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-" + f"{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn1}-{self.bk1}-{self.bk0max}-" + f"{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-" + f"{self.dvpad}-{self.pagedkv}-{self.sink}" ) @@ -552,16 +661,23 @@ class FmhaFwdSplitKVApiPool: @property def api(self) -> str: - per_arch = str() + per_arch_list_all = str() + per_arch_run_all_all = str() + per_arch_exec_all = str() for i_arch, (arch, pool_by_arch) in enumerate(self.pool.items()): - per_dtypes = str() + per_dtypes_list = str() + per_dtypes_run_all = str() + per_dtypes_exec = str() for i_dtype, (dtype, pool_by_dtype) in enumerate(pool_by_arch.items()): - per_hdim_case = str() + per_hdim_case_list = str() + per_hdim_case_run_all = str() + per_hdim_case_exec = str() for i_hdim, (hdim, pool_by_hdim) in enumerate(pool_by_dtype.items()): - inners = str() + inners_list = str() + inners_run_all = str() + inners_exec = str() for i_trait, trait in enumerate(pool_by_hdim): - inners += FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format( - F_if=if_(i_trait), + common = dict( F_arch=arch, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout], @@ -591,27 +707,76 @@ class FmhaFwdSplitKVApiPool: F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype], + F_kernel_name=trait.name, + F_runtime_name=trait.runtime_name, F_bn1comb=trait.bn1comb, ) - per_hdim_case += FMHA_FWD_API_PER_HDIM_CASE.format( + inners_list += FMHA_FWD_SPLITKV_API_INNER_DISPATCH_LIST.format( + **common, + ) + inners_run_all += FMHA_FWD_SPLITKV_API_INNER_DISPATCH_RUN_ALL.format( + **common, + ) + inners_exec += FMHA_FWD_SPLITKV_API_INNER_DISPATCH_EXEC.format( + F_if=if_(i_trait), + **common, + ) + per_hdim_case_list += FMHA_FWD_API_PER_HDIM_CASE.format( F_if=if_(i_hdim), F_hdim=hdim, F_hdim_v=hdim, - F_inner_dispatch=indent(inners), + F_inner_dispatch=indent(inners_list), ) - per_dtypes += FMHA_FWD_API_PER_DTYPE.format( - F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case) + per_hdim_case_exec += FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_(i_hdim), + F_hdim=hdim, + F_hdim_v=hdim, + F_inner_dispatch=indent(inners_exec), + ) + per_hdim_case_run_all += FMHA_FWD_API_PER_HDIM_CASE.format( + F_if=if_(i_hdim), + F_hdim=hdim, + F_hdim_v=hdim, + F_inner_dispatch=indent(inners_run_all), + ) + per_dtypes_list += FMHA_FWD_API_PER_DTYPE.format( + F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case_list) ) - per_arch += FMHA_FWD_API_PER_ARCH.format( + per_dtypes_run_all += FMHA_FWD_API_PER_DTYPE.format( + F_if=if_(i_dtype), + F_dtype=dtype, + F_hdim_case=indent(per_hdim_case_run_all), + ) + per_dtypes_exec += FMHA_FWD_API_PER_DTYPE.format( + F_if=if_(i_dtype), F_dtype=dtype, F_hdim_case=indent(per_hdim_case_exec) + ) + per_arch_list = FMHA_FWD_API_PER_ARCH.format( F_if=if_(i_arch), F_arch=arch, - F_dtype_case=indent(per_dtypes), + F_dtype_case=indent(per_dtypes_list), ) - if not per_arch: + per_arch_exec = FMHA_FWD_API_PER_ARCH.format( + F_if=if_(i_arch), + F_arch=arch, + F_dtype_case=indent(per_dtypes_exec), + ) + per_arch_run_all = FMHA_FWD_API_PER_ARCH.format( + F_if=if_(i_arch), + F_arch=arch, + F_dtype_case=indent(per_dtypes_run_all), + ) + per_arch_list_all += per_arch_list + per_arch_run_all_all += per_arch_run_all + per_arch_exec_all += per_arch_exec + if not per_arch_list_all: # empty string we add some ignore to suppress warning in api - per_arch = "(void)t; (void)s; (void)a;" + per_arch_list_all = "(void)t; (void)s; (void)a;" + per_arch_run_all_all = "(void)t; (void)s; (void)a;" + per_arch_exec_all = "(void)t; (void)s; (void)a;" return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format( - F_dispatch=indent(per_arch) + F_dispatch_list=indent(per_arch_list_all), + F_dispatch_run_all=indent(per_arch_run_all_all), + F_dispatch_exec=indent(per_arch_exec_all), ) @@ -823,7 +988,24 @@ class KernelComponentFactoryGfx9(KernelComponentFactoryBase): "32" : FmhaFwdTileSize( 32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), "64" : FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), "96" : FmhaFwdTileSize( 64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), - "128": FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + "128": [ + FmhaFwdTileSize( 16, 32, 32, 128, 16, 128, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 16, 64, 32, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 16, 128, 32, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 16, 256, 32, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 32, 32, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 32, 64, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 32, 128, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 32, 256, 32, 128, 32, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 256, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + FmhaFwdTileSize(128, 32, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 256, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ], # "160" : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), "256": FmhaFwdTileSize( 64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } # fmt: skip @@ -886,77 +1068,80 @@ def get_fwd_splitkv_blobs( continue # for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] hdim = int(hdim_str) - for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): - if mode == "group": - if pipeline.F_spad != "t" or pipeline.F_skpad != "t": - # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not - continue - # logits_soft_cap is only allowed if no bias - if not ( - (pipeline.F_logits == "t" and pipeline.F_bias == "no") - or pipeline.F_logits == "f" - ): - continue - k = Kernel( - F_arch=factory.arch, - F_idx=0, - F_hdim=hdim, - F_dtype=dtype, - F_mode=mode, - F_tile=tile, - F_pipeline=pipeline, - mask_impl=mask_impl, - ) - if kernel_filter != "": - if not fnmatch.fnmatch(k.name, kernel_filter): - continue - if optdim_list != [-1]: - if hdim not in optdim_list: - continue - # Flash attention integration - if receipt == 2: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "alibi"] - cond &= pipeline.F_squant == "f" - cond &= pipeline.F_sink == "f" - if not cond: - continue - # PyTorch integration - elif receipt == 4: - cond = dtype in ["fp16, bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_bias in ["no", "bias"] - cond &= pipeline.F_squant == "f" - cond &= mode == "batch" - cond &= pipeline.F_sink == "f" - if not cond: - continue - # Aiter(mha_varlen_fwd) integration - elif receipt == 200: - cond = dtype in ["fp16", "bf16"] - cond &= mode == "group" - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" - if not cond: - continue - # aiter::mha_fwd_splikv C++ api integration - elif receipt == 600: - cond = dtype in ["fp16", "bf16"] - cond &= pipeline.F_vlayout == "row" - cond &= pipeline.F_squant == "f" - if not cond: + tiles = d[hdim_str] + if not isinstance(tiles, list): + tiles = [tiles] + for tile in tiles: + for pipeline in factory.get_pipelines(dtype, hdim, mask_impl): + if mode == "group": + if pipeline.F_spad != "t" or pipeline.F_skpad != "t": + # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not + continue + # logits_soft_cap is only allowed if no bias + if not ( + (pipeline.F_logits == "t" and pipeline.F_bias == "no") + or pipeline.F_logits == "f" + ): continue + k = Kernel( + F_arch=factory.arch, + F_idx=0, + F_hdim=hdim, + F_dtype=dtype, + F_mode=mode, + F_tile=tile, + F_pipeline=pipeline, + mask_impl=mask_impl, + ) + if kernel_filter != "": + if not fnmatch.fnmatch(k.name, kernel_filter): + continue + if optdim_list != [-1]: + if hdim not in optdim_list: + continue + # Flash attention integration + if receipt == 2: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "alibi"] + cond &= pipeline.F_squant == "f" + cond &= pipeline.F_sink == "f" + if not cond: + continue + # PyTorch integration + elif receipt == 4: + cond = dtype in ["fp16, bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_bias in ["no", "bias"] + cond &= pipeline.F_squant == "f" + cond &= mode == "batch" + cond &= pipeline.F_sink == "f" + if not cond: + continue + # Aiter(mha_varlen_fwd) integration + elif receipt == 200: + cond = dtype in ["fp16", "bf16"] + cond &= mode == "group" + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" + if not cond: + continue + # aiter::mha_fwd_splikv C++ api integration + elif receipt == 600: + cond = dtype in ["fp16", "bf16"] + cond &= pipeline.F_vlayout == "row" + cond &= pipeline.F_squant == "f" + if not cond: + continue - # fp32 only - if receipt == 800 or receipt == 801: - cond = dtype == "fp32" - if not cond: - continue + # fp32 only + if receipt == 800 or receipt == 801: + cond = dtype == "fp32" + if not cond: + continue - gen.append(k) + gen.append(k) return gen @@ -1096,6 +1281,7 @@ def write_blobs( dpad=kernel.F_pipeline.F_dpad, dvpad=kernel.F_pipeline.F_dvpad, bn1comb=combine_kernel.F_tile.F_bn1, + runtime_name=kernel.name, ) ) write_fwd_splitkv_api(api_pool, output_dir) diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index f5ad6b2bc5..d78bcae8dc 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -77,6 +77,15 @@ auto create_args(int argc, char* argv[]) .insert("vlayout", "r", "r for row-major(seqlen*hdim), c for col-major(hdim*seqlen)") .insert("lse", "0", "0 not store lse, 1 store lse") .insert("kname", "0", "if set to 1 will print kernel name") + .insert("kernel_filter", + "", + "optional substring filter for selecting a specific prebuilt fmha_fwd kernel") + .insert("list_kernels", + "0", + "if set to 1, list compatible fmha_fwd kernels and do not launch") + .insert("run_all_kernels", + "0", + "if set to 1, run all compatible fmha_fwd kernels and report the best time") .insert("init", "uf", "init method:\n ui or 0 - uniform random int\n ni - normalized random int" @@ -146,19 +155,22 @@ auto run(const ck_tile::ArgParser& arg_parser) bool is_v_rowmajor = arg_parser.get_str("vlayout") == "r"; bool lse = arg_parser.get_bool("lse"); ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size"); - bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); - std::string bias_str = arg_parser.get_str("bias"); - std::string qscale_str = arg_parser.get_str("qscale"); - float p_drop = arg_parser.get_float("p_drop"); - uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); - uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); - bool drop_prefs = arg_parser.get_bool("drop_prefs"); - std::string mask_str = arg_parser.get_str("mask"); - bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); - ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); - std::string init_method = arg_parser.get_str("init"); - uint32_t seed = arg_parser.get_uint32("seed"); - int init_sink_value = arg_parser.get_int("init_sink"); + bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx"); + std::string bias_str = arg_parser.get_str("bias"); + std::string qscale_str = arg_parser.get_str("qscale"); + float p_drop = arg_parser.get_float("p_drop"); + uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); + uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); + bool drop_prefs = arg_parser.get_bool("drop_prefs"); + std::string mask_str = arg_parser.get_str("mask"); + bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); + ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); + std::string kernel_filter = arg_parser.get_str("kernel_filter"); + bool list_kernels = arg_parser.get_bool("list_kernels"); + bool run_all_kernels = arg_parser.get_bool("run_all_kernels"); + std::string init_method = arg_parser.get_str("init"); + uint32_t seed = arg_parser.get_uint32("seed"); + int init_sink_value = arg_parser.get_int("init_sink"); ck_tile::stream_config stream_config{nullptr, true, @@ -202,6 +214,9 @@ auto run(const ck_tile::ArgParser& arg_parser) qscale_str, is_rotary_interleaved, num_splits, + kernel_filter, + list_kernels, + run_all_kernels, init_method, seed, do_validation, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index ee404010ef..01ecfa3acd 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -1598,6 +1598,11 @@ struct fmha_fwd_traits quant_scale_enum qscale_type; bool skip_min_seqlen_q = false; bool has_sink = false; + std::string kernel_filter = ""; + bool list_kernels = false; + bool run_all_kernels = false; + std::size_t perf_flop = 0; + std::size_t perf_num_byte = 0; // TODO: padding check is inside this api }; float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&); @@ -1637,6 +1642,11 @@ struct fmha_fwd_splitkv_traits bool has_lse; bool do_fp8_static_quant; bool has_sink = false; + std::string kernel_filter = ""; + bool list_kernels = false; + bool run_all_kernels = false; + std::size_t perf_flop = 0; + std::size_t perf_num_byte = 0; // TODO: padding check is inside this api }; float fmha_fwd_splitkv(fmha_fwd_splitkv_traits, diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 1227724d40..8853be63ed 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -203,6 +203,9 @@ fwd_result fmha_fwd_run(mode_enum mode, std::string qscale_str, bool is_rotary_interleaved, ck_tile::index_t num_splits, + const std::string& kernel_filter, + bool list_kernels, + bool run_all_kernels, std::string init_method, uint32_t seed, int do_validation, @@ -974,11 +977,27 @@ fwd_result fmha_fwd_run(mode_enum mode, { traits.has_dropout = (p_drop > 0.0f); traits.qscale_type = qscale.type; + traits.kernel_filter = kernel_filter; + traits.list_kernels = list_kernels; + traits.run_all_kernels = run_all_kernels; + traits.perf_flop = flop; + traits.perf_num_byte = num_byte; } else if constexpr(std::is_same_v>) { traits.use_pagedkv = (0 < page_block_size); + traits.do_fp8_static_quant = (data_type == "fp8"); + } + else if constexpr(std::is_same_v>) + { + traits.kernel_filter = kernel_filter; + traits.list_kernels = list_kernels; + traits.run_all_kernels = run_all_kernels; + traits.perf_flop = flop; + traits.perf_num_byte = num_byte; + traits.do_fp8_static_quant = (data_type == "fp8"); } } }; @@ -1363,9 +1382,19 @@ fwd_result fmha_fwd_run(mode_enum mode, fmha_fwd_args fmha_args; init_args(fmha_args); + if(fmha_traits.list_kernels || fmha_traits.run_all_kernels) + { + std::cout << std::endl; + } + return fmha_fwd(fmha_traits, fmha_args, sc); }; const float fwd_ave_time = run_fwd(stream_config); + if(list_kernels) + { + std::cout << std::flush << std::endl; + return fwd_result::success; + } if(fwd_ave_time < 0.0f) { std::cout << ", not supported yet" << std::flush << std::endl;