[rocm-libraries] ROCm/rocm-libraries#5863 (commit 31d9247)

[CK_TILE] Separate PermuteN epilogue from CShuffle epilogue
 into standalone file (#5863)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

The PermuteN epilogue was previously embedded within
cshuffle_epilogue.hpp, despite having fundamentally different behaviour.
Coupling these two independent strategies in one file introduced
unnecessary complexity, SFINAE guards, and a dual operator() overload
selected at compile time via TiledMMAPermuteN_ template parameter.

This PR separates PermuteN into its own standalone
file(pertmuten_epilogue.hpp), simplifying both implementations and
making the codebase easier to maintain and extend independently.

## Technical Details

**New file: permuten_epilogue.hpp:**
contains PermuteNEpilogueProblem and PermuteNEpilogue, extracted from
the permuteN code path in cshuffle_epilogue.hpp.

**Cleanup of cshuffle_epilogue.hpp:**

- Removed the TiledMMAPermuteN_ template parameter from
[CShuffleEpilogueProblem]
- Removed the SFINAE-guarded permuteN operator() overload
- Removed the EnablePermuateN_ SFINAE alias
- CShuffle now only contains CShuffle logic; EightWave support
(independent feature) is retained

**Consumer migration :**
All consumer files now use compile-time epilogue selection via
[std::conditional_t]

`using GemmEpilogue = std::conditional_t<
    TiledMMAPermuteN,
    PermuteNEpilogue<PermuteNEpilogueProblem<...>>,
    CShuffleEpilogue<CShuffleEpilogueProblem<...>>>;`

**Files modified:**

- flatmm_basic.cpp, moe_flatmm.cpp, a16w4_moe_flatmm.cpp,
mixed_prec_flatmm.cpp, mx_flatmm_instance.hpp — flatmm examples
- run_gemm_quant_example.inc — block-scale GEMM example
- gemm_weight_preshuffle_invoker.hpp — weight preshuffle invoker
- test_gemm_quant_fixtures.hpp, test_gemm_persistent_async_input.cpp,
test_gemm_pipeline_util.hpp — test utilities
- universal_gemm_invoker.hpp — universal GEMM invoker
- epilogue.hpp — add header updated to include permuten_epilogue.hpp

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
msaffari-amd
2026-04-14 20:23:26 +00:00
committed by assistant-librarian[bot]
parent 5d2fce819d
commit 5348b577ed
14 changed files with 728 additions and 333 deletions

View File

@@ -188,27 +188,45 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN>>;
using GemmEpilogue = std::conditional_t<
FlatmmConfig::TiledMMAPermuteN,
ck_tile::PermuteNEpilogue<
ck_tile::PermuteNEpilogueProblem<ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
false,
1>>,
ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
FlatmmConfig::NumWaveGroups>>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
@@ -230,6 +248,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "epilogue: " << GemmEpilogue::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;