[CK_TILE] Separate PermuteN epilogue from CShuffle epilogue into standalone file (#5863)

## Motivation

The PermuteN epilogue was previously embedded within
cshuffle_epilogue.hpp, despite having fundamentally different behaviour.
Coupling these two independent strategies in one file introduced
unnecessary complexity, SFINAE guards, and a dual operator() overload
selected at compile time via the TiledMMAPermuteN_ template parameter.

This PR separates PermuteN into its own standalone
file (permuten_epilogue.hpp), simplifying both implementations and
making the codebase easier to maintain and extend independently.

## Technical Details

**New file: permuten_epilogue.hpp:** 
contains PermuteNEpilogueProblem and PermuteNEpilogue, extracted from
the permuteN code path in cshuffle_epilogue.hpp.

**Cleanup of cshuffle_epilogue.hpp:**

- Removed the TiledMMAPermuteN_ template parameter from
`CShuffleEpilogueProblem`
- Removed the SFINAE-guarded permuteN operator() overload
- Removed the EnablePermuateN_ SFINAE alias
- CShuffle now only contains CShuffle logic; EightWave support
(independent feature) is retained

**Consumer migration:**
All consumer files now use compile-time epilogue selection via
`std::conditional_t`:

```cpp
using GemmEpilogue = std::conditional_t<
    TiledMMAPermuteN,
    PermuteNEpilogue<PermuteNEpilogueProblem<...>>,
    CShuffleEpilogue<CShuffleEpilogueProblem<...>>>;
```

**Files modified:**

- flatmm_basic.cpp, moe_flatmm.cpp, a16w4_moe_flatmm.cpp,
mixed_prec_flatmm.cpp, mx_flatmm_instance.hpp — flatmm examples
- run_gemm_quant_example.inc — block-scale GEMM example
- gemm_weight_preshuffle_invoker.hpp — weight preshuffle invoker
- test_gemm_quant_fixtures.hpp, test_gemm_persistent_async_input.cpp,
test_gemm_pipeline_util.hpp — test utilities
- universal_gemm_invoker.hpp — universal GEMM invoker
- epilogue.hpp — header updated to include permuten_epilogue.hpp



## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
msaffari-amd
2026-04-14 22:22:18 +02:00
committed by GitHub
parent 5f2517da31
commit 6072031cf4
14 changed files with 728 additions and 333 deletions

View File

@@ -221,7 +221,6 @@ class TestCkTileGemmPipeline : public ::testing::Test
1, /*kNumWaveGroups_*/
false, /*FixedVectorSize_*/
1, /*VectorSizeC_*/
false, /*TiledMMAPermuteN_*/
1, /*BlockedXDLN_PerWarp_*/
DoubleSmemBuffer /*DoubleSmemBuffer*/>>;

View File

@@ -937,29 +937,49 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGem
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
ADataType,
BDataType>,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false, // transpose_c
1,
false,
1,
TiledMMAPermuteN>>;
// clang-format off
using BTypeForEpilogue =
std::conditional_t<std::is_same_v<BDataType, ck_tile::pk_fp4_t>, ADataType, BDataType>;
// clang-format on
using GemmEpilogue = std::conditional_t<
TiledMMAPermuteN,
ck_tile::PermuteNEpilogue<
ck_tile::PermuteNEpilogueProblem<ADataType,
BTypeForEpilogue,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false, // transpose_c
false,
1>>,
ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BTypeForEpilogue,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false>>>; // transpose_c
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
@@ -1281,27 +1301,44 @@ class TestCkTileGemmABQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGe
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
1,
false,
1,
TiledMMAPermuteN>>;
using GemmEpilogue = std::conditional_t<
TiledMMAPermuteN,
ck_tile::PermuteNEpilogue<
ck_tile::PermuteNEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
false,
1>>,
ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
typename PipelineProblem::ComputeDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c>>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,

View File

@@ -159,12 +159,11 @@ class TestGemmPersistentAsyncInput : public ::testing::Test
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
1, // kNumWaveGroups_
false, // FixedVectorSize_
1, // VectorSizeC_
false, // TiledMMAPermuteN_
1, // BlockedXDLN_PerWarp_
DoubleSmemBuffer>>;
1, /*kNumWaveGroups_*/
false, /*FixedVectorSize_*/
1, /*VectorSizeC_*/
1, /*BlockedXDLN_PerWarp_*/
DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;