* Remove some duplicate code in fmha_fwd_appendkv_kernel.hpp

* Simplify two templated operator calls by having the templated types deduced automatically

* Simplify two GemmPipeline calls

* Fix GemmPipelineAgBgCrCompV4::GetName

* Refactor use of ArgParser in CK tile GEMM examples

* Update args in README.md to match the implementation in create_args

* Remove some unnecessary include statements

* Rename two variables

* Factor out common code

* Factor out do_verify

* Add and use type aliases for memory operation integral constants

* In gemm_basic.cpp, use kPadM, kPadN, kPadK, and kBlockPerCu from GemmConfig

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>

[ROCm/composable_kernel commit: 28a97865f5]
This commit is contained in:
SamiAario-AMD
2025-08-13 11:12:08 +03:00
committed by GitHub
parent e3e8de3477
commit 8a32077ccd
9 changed files with 142 additions and 178 deletions

View File

@@ -647,44 +647,25 @@ struct FmhaFwdAppendKVKernel
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, i_n0});
if constexpr(kApplyRoPE)
{
FmhaPipeline{}(q_dram_window,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
kargs.rotary_dim,
kargs.seqlen_q <= i_m0,
skip_append_kv);
}
else
{
FmhaPipeline{}(q_dram_window,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
0, // rotary_dim not used
kargs.seqlen_q <= i_m0,
skip_append_kv);
}
// If kApplyRoPe is false, we set the rotary_dim to 0
auto rotary_dim = kApplyRoPE ? kargs.rotary_dim : 0;
FmhaPipeline{}(q_dram_window,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
rotary_dim,
kargs.seqlen_q <= i_m0,
skip_append_kv);
}
};

View File

@@ -943,17 +943,15 @@ struct UniversalGemmKernel
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
const auto& c_block_tile =
GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(ds_block_window)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
}
@@ -1001,15 +999,13 @@ struct UniversalGemmKernel
const auto& bs_block_window = gemm_tile_windows.at(I1);
const auto& ds_block_window = gemm_tile_windows.at(I2);
const auto& c_block_tile = GemmPipeline{}.template operator()(
const auto& c_block_tile = GemmPipeline{}(
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(ds_block_window)>(
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
}
// Non-persistent kernel entry point

View File

@@ -149,7 +149,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "pipeline_AgBgCrCompV3",
return concat('_', "pipeline_AgBgCrCompV4",
concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize),
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', kPadM, kPadN, kPadK));