mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
[rocm-libraries] ROCm/rocm-libraries#5863 (commit 31d9247)
[CK_TILE] Separate PermuteN epilogue from CShuffle epilogue into standalone file (#5863) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Motivation The PermuteN epilogue was previously embedded within cshuffle_epilogue.hpp, despite having fundamentally different behaviour. Coupling these two independent strategies in one file introduced unnecessary complexity, SFINAE guards, and a dual operator() overload selected at compile time via TiledMMAPermuteN_ template parameter. This PR separates PermuteN into its own standalone file(pertmuten_epilogue.hpp), simplifying both implementations and making the codebase easier to maintain and extend independently. ## Technical Details **New file: permuten_epilogue.hpp:** contains PermuteNEpilogueProblem and PermuteNEpilogue, extracted from the permuteN code path in cshuffle_epilogue.hpp. **Cleanup of cshuffle_epilogue.hpp:** - Removed the TiledMMAPermuteN_ template parameter from [CShuffleEpilogueProblem] - Removed the SFINAE-guarded permuteN operator() overload - Removed the EnablePermuateN_ SFINAE alias - CShuffle now only contains CShuffle logic; EightWave support (independent feature) is retained **Consumer migration :** All consumer files now use compile-time epilogue selection via [std::conditional_t] `using GemmEpilogue = std::conditional_t< TiledMMAPermuteN, PermuteNEpilogue<PermuteNEpilogueProblem<...>>, CShuffleEpilogue<CShuffleEpilogueProblem<...>>>;` **Files modified:** - flatmm_basic.cpp, moe_flatmm.cpp, a16w4_moe_flatmm.cpp, mixed_prec_flatmm.cpp, mx_flatmm_instance.hpp — flatmm examples - run_gemm_quant_example.inc — block-scale GEMM example - gemm_weight_preshuffle_invoker.hpp — weight preshuffle invoker - test_gemm_quant_fixtures.hpp, test_gemm_persistent_async_input.cpp, test_gemm_pipeline_util.hpp — test utilities - universal_gemm_invoker.hpp — universal GEMM invoker - epilogue.hpp — add header updated to include permuten_epilogue.hpp ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
5d2fce819d
commit
5348b577ed
@@ -188,27 +188,45 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups>>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
@@ -230,6 +248,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "epilogue: " << GemmEpilogue::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
|
||||
Reference in New Issue
Block a user