mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
[CK_TILE] Separate PermuteN epilogue from CShuffle epilogue into standalone file (#5863)
## Motivation
The PermuteN epilogue was previously embedded within
cshuffle_epilogue.hpp, despite having fundamentally different behaviour.
Coupling these two independent strategies in one file introduced
unnecessary complexity, SFINAE guards, and a dual operator() overload
selected at compile time via TiledMMAPermuteN_ template parameter.
This PR separates PermuteN into its own standalone
file(pertmuten_epilogue.hpp), simplifying both implementations and
making the codebase easier to maintain and extend independently.
## Technical Details
**New file: permuten_epilogue.hpp:**
contains PermuteNEpilogueProblem and PermuteNEpilogue, extracted from
the permuteN code path in cshuffle_epilogue.hpp.
**Cleanup of cshuffle_epilogue.hpp:**
- Removed the TiledMMAPermuteN_ template parameter from
[CShuffleEpilogueProblem]
- Removed the SFINAE-guarded permuteN operator() overload
- Removed the EnablePermuateN_ SFINAE alias
- CShuffle now only contains CShuffle logic; EightWave support
(independent feature) is retained
**Consumer migration :**
All consumer files now use compile-time epilogue selection via
[std::conditional_t]
`using GemmEpilogue = std::conditional_t<
TiledMMAPermuteN,
PermuteNEpilogue<PermuteNEpilogueProblem<...>>,
CShuffleEpilogue<CShuffleEpilogueProblem<...>>>;`
**Files modified:**
- flatmm_basic.cpp, moe_flatmm.cpp, a16w4_moe_flatmm.cpp,
mixed_prec_flatmm.cpp, mx_flatmm_instance.hpp — flatmm examples
- run_gemm_quant_example.inc — block-scale GEMM example
- gemm_weight_preshuffle_invoker.hpp — weight preshuffle invoker
- test_gemm_quant_fixtures.hpp, test_gemm_persistent_async_input.cpp,
test_gemm_pipeline_util.hpp — test utilities
- universal_gemm_invoker.hpp — universal GEMM invoker
- epilogue.hpp — add header updated to include permuten_epilogue.hpp
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -221,7 +221,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
|
||||
1, /*kNumWaveGroups_*/
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
|
||||
@@ -937,29 +937,49 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
|
||||
ADataType,
|
||||
BDataType>,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
false, // transpose_c
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledMMAPermuteN>>;
|
||||
// clang-format off
|
||||
using BTypeForEpilogue =
|
||||
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>, ADataType, BDataType>;
|
||||
// clang-format on
|
||||
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BTypeForEpilogue,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
false, // transpose_c
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BTypeForEpilogue,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
false>>>; // transpose_c
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
@@ -1281,27 +1301,44 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
|
||||
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
|
||||
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledMMAPermuteN>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c>>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
|
||||
@@ -159,12 +159,11 @@ class TestGemmPersistentAsyncInput : public ::testing::Test
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
1, // kNumWaveGroups_
|
||||
false, // FixedVectorSize_
|
||||
1, // VectorSizeC_
|
||||
false, // TiledMMAPermuteN_
|
||||
1, // BlockedXDLN_PerWarp_
|
||||
DoubleSmemBuffer>>;
|
||||
1, /*kNumWaveGroups_*/
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user