mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
[CK_TILE] Separate PermuteN epilogue from CShuffle epilogue into standalone file (#5863)
## Motivation
The PermuteN epilogue was previously embedded within
cshuffle_epilogue.hpp, despite having fundamentally different behaviour.
Coupling these two independent strategies in one file introduced
unnecessary complexity, SFINAE guards, and a dual operator() overload
selected at compile time via TiledMMAPermuteN_ template parameter.
This PR separates PermuteN into its own standalone
file(pertmuten_epilogue.hpp), simplifying both implementations and
making the codebase easier to maintain and extend independently.
## Technical Details
**New file: permuten_epilogue.hpp:**
contains PermuteNEpilogueProblem and PermuteNEpilogue, extracted from
the permuteN code path in cshuffle_epilogue.hpp.
**Cleanup of cshuffle_epilogue.hpp:**
- Removed the TiledMMAPermuteN_ template parameter from
[CShuffleEpilogueProblem]
- Removed the SFINAE-guarded permuteN operator() overload
- Removed the EnablePermuateN_ SFINAE alias
- CShuffle now only contains CShuffle logic; EightWave support
(independent feature) is retained
**Consumer migration :**
All consumer files now use compile-time epilogue selection via
[std::conditional_t]
`using GemmEpilogue = std::conditional_t<
TiledMMAPermuteN,
PermuteNEpilogue<PermuteNEpilogueProblem<...>>,
CShuffleEpilogue<CShuffleEpilogueProblem<...>>>;`
**Files modified:**
- flatmm_basic.cpp, moe_flatmm.cpp, a16w4_moe_flatmm.cpp,
mixed_prec_flatmm.cpp, mx_flatmm_instance.hpp — flatmm examples
- run_gemm_quant_example.inc — block-scale GEMM example
- gemm_weight_preshuffle_invoker.hpp — weight preshuffle invoker
- test_gemm_quant_fixtures.hpp, test_gemm_persistent_async_input.cpp,
test_gemm_pipeline_util.hpp — test utilities
- universal_gemm_invoker.hpp — universal GEMM invoker
- epilogue.hpp — add header updated to include permuten_epilogue.hpp
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -58,27 +58,45 @@ struct WeightPreshuffleInvoker
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
GemmConfig::TiledMMAPermuteN>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
GemmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
GemmConfig::NumWaveGroups>>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
|
||||
@@ -84,7 +84,6 @@ struct UniversalInvoker
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer /*DoubleSmemBuffer*/>>;
|
||||
|
||||
@@ -228,7 +227,6 @@ struct UniversalInvoker
|
||||
GemmConfig::NumWaveGroups,
|
||||
false, /*FixedVectorSize_*/
|
||||
1, /*VectorSizeC_*/
|
||||
false, /*TiledMMAPermuteN_*/
|
||||
1, /*BlockedXDLN_PerWarp_*/
|
||||
GemmConfig::DoubleSmemBuffer>>;
|
||||
|
||||
|
||||
@@ -188,27 +188,45 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups>>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
@@ -230,6 +248,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "epilogue: " << GemmEpilogue::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
|
||||
@@ -139,28 +139,48 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
BlockedXDLN_PerWarp>>>;
|
||||
|
||||
using CodegenFlatmmPipeline = std::conditional_t<
|
||||
MXFP4_Pipeline,
|
||||
|
||||
@@ -108,28 +108,48 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>&
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false, // FixedVectorSize
|
||||
1>>, // VectorSizeC
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
BlockedXDLN_PerWarp>>>;
|
||||
|
||||
using Kernel =
|
||||
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
@@ -163,28 +163,48 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
? 2
|
||||
: 1; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
BlockedXDLN_PerWarp>>>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
@@ -84,7 +84,26 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
using GemmEpilogue =
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
ck_tile::PermuteNEpilogue<ck_tile::PermuteNEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
false, // FixedVectorSize
|
||||
1>>, // VectorSizeC
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
@@ -104,8 +123,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
BlockedXDLN_PerWarp>>>;
|
||||
|
||||
using Kernel = ck_tile::MXFlatmmKernel<TilePartitioner, MXFlatmmPipeline, GemmEpilogue>;
|
||||
|
||||
|
||||
@@ -207,27 +207,44 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
|
||||
printf(
|
||||
"TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN);
|
||||
}
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
1,
|
||||
false,
|
||||
1,
|
||||
TiledPermuteN>>;
|
||||
using GemmEpilogue = std::conditional_t<
|
||||
TiledPermuteN,
|
||||
ck_tile::PermuteNEpilogue<
|
||||
ck_tile::PermuteNEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c,
|
||||
false,
|
||||
1>>,
|
||||
ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<typename PipelineProblem::ComputeDataType,
|
||||
typename PipelineProblem::ComputeDataType,
|
||||
ck_tile::tuple<>,
|
||||
typename TypeConfig::AccDataType,
|
||||
typename TypeConfig::CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
CDEElementWise,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
transpose_c>>>;
|
||||
using Kernel =
|
||||
ck_tile::QuantGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue, QuantMode>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user