mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 20:09:25 +00:00
Cleanups (#2631)
* Remove some duplicate code in fmha_fwd_appendkv_kernel.hpp
* Simplify two templated operator calls by having the templated types deduced automatically
* Simplify two GemmPipeline calls
* Fix GemmPipelineAgBgCrCompV4::GetName
* Refactor use of ArgParser in CK tile GEMM examples
* Update args in README.md to match the implementation in create_args
* Remove some unnecessary include statements
* Rename two variables
* Factor out common code
* Factor out do_verify
* Add and use type aliases for memory operation integral constants
* In gemm_basic.cpp, use kPadM, kPadN, kPadK, and kBlockPerCu from GemmConfig
---------
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
[ROCm/composable_kernel commit: 28a97865f5]
This commit is contained in:
@@ -647,44 +647,25 @@ struct FmhaFwdAppendKVKernel
|
||||
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
|
||||
{0, i_n0});
|
||||
|
||||
if constexpr(kApplyRoPE)
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
i_page_block_k,
|
||||
k_page_block_navigator,
|
||||
knew_dram_window,
|
||||
v_dram_window,
|
||||
i_page_block_v,
|
||||
v_page_block_navigator,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
kargs.rotary_dim,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
skip_append_kv);
|
||||
}
|
||||
else
|
||||
{
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
i_page_block_k,
|
||||
k_page_block_navigator,
|
||||
knew_dram_window,
|
||||
v_dram_window,
|
||||
i_page_block_v,
|
||||
v_page_block_navigator,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
0, // rotary_dim not used
|
||||
kargs.seqlen_q <= i_m0,
|
||||
skip_append_kv);
|
||||
}
|
||||
// If kApplyRoPe is false, we set the rotary_dim to 0
|
||||
auto rotary_dim = kApplyRoPE ? kargs.rotary_dim : 0;
|
||||
|
||||
FmhaPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
i_page_block_k,
|
||||
k_page_block_navigator,
|
||||
knew_dram_window,
|
||||
v_dram_window,
|
||||
i_page_block_v,
|
||||
v_page_block_navigator,
|
||||
vnew_dram_window,
|
||||
q_rotary_cos_dram_window,
|
||||
q_rotary_sin_dram_window,
|
||||
knew_rotary_cos_dram_window,
|
||||
knew_rotary_sin_dram_window,
|
||||
rotary_dim,
|
||||
kargs.seqlen_q <= i_m0,
|
||||
skip_append_kv);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -943,17 +943,15 @@ struct UniversalGemmKernel
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
|
||||
const auto& c_block_tile =
|
||||
GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
|
||||
|
||||
if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(ds_block_window)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1001,15 +999,13 @@ struct UniversalGemmKernel
|
||||
const auto& bs_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& ds_block_window = gemm_tile_windows.at(I2);
|
||||
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(
|
||||
const auto& c_block_tile = GemmPipeline{}(
|
||||
as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(ds_block_window)>(
|
||||
c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
// Non-persistent kernel entry point
|
||||
|
||||
@@ -149,7 +149,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV3",
|
||||
return concat('_', "pipeline_AgBgCrCompV4",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
|
||||
Reference in New Issue
Block a user