mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
WIP
This commit is contained in:
@@ -92,8 +92,8 @@ float invoke_mx_gemm(ck_tile::DeviceMem& a_dev_buf,
|
||||
MXGemmTraits,
|
||||
GemmConfig::Scheduler>;
|
||||
|
||||
// Use the new comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
// Use the new MX comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
// Simplified invocation - comp_async handles hot loop and tail internally
|
||||
auto invoke_splitk_path = [&](auto split_k_) {
|
||||
|
||||
@@ -25,6 +25,15 @@ struct MXGemmPipelineProblem : ck_tile::GemmPipelineProblem<ADataType, BDataType
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
};
|
||||
|
||||
// Epilogue wrapper that adds MemoryOperation member for MX GEMM kernel compatibility
|
||||
template <typename BaseEpilogue_, ck_tile::memory_operation_enum MemOp_>
|
||||
struct MXGemmEpilogueWrapper : BaseEpilogue_
|
||||
{
|
||||
static constexpr ck_tile::memory_operation_enum MemoryOperation = MemOp_;
|
||||
using BaseEpilogue_::BaseEpilogue_;
|
||||
using BaseEpilogue_::operator();
|
||||
};
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
@@ -75,15 +84,15 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args,
|
||||
MXGemmTraits,
|
||||
scheduler>;
|
||||
|
||||
// Use the new comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::GemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
// Use the new MX comp_async pipeline with MX scaling support
|
||||
using MXGemmPipeline = ck_tile::MXGemmPipelineAgBgCrCompAsync<MXPipelineProblem>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmEpilogue =
|
||||
using BaseEpilogue =
|
||||
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
ck_tile::tuple<>, // DsDataType
|
||||
@@ -100,11 +109,14 @@ float mx_gemm_calc(const MXGemmHostArgs<ScaleM, ScaleN>& args,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
MXPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::NumWaveGroups, // kNumWaveGroups
|
||||
false, // FixedVectorSize
|
||||
1, // VectorSizeC
|
||||
false>>; // PermuteN
|
||||
false, // TiledMMAPermuteN
|
||||
1, // BlockedXDLN_PerWarp
|
||||
false>>; // DoubleSmemBuffer
|
||||
|
||||
using GemmEpilogue = MXGemmEpilogueWrapper<BaseEpilogue, memory_operation>;
|
||||
|
||||
using Kernel = ck_tile::MXGemmKernel<TilePartitioner, MXGemmPipeline, GemmEpilogue>;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user