Update code

This commit is contained in:
mtgu0705
2025-08-30 03:19:07 -05:00
parent 9c37e55d13
commit 16993acd1d
9 changed files with 2095 additions and 88 deletions

View File

@@ -11,7 +11,7 @@
#include <type_traits>
#include "ck_tile/host.hpp"
#include "mx_prec_flatmm.hpp"
#include "mx_flatmm.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
@@ -99,17 +99,17 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using CodegenPipelineProblem = ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using CodegenMXFlatmmPipeline =
ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
@@ -137,7 +137,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
BlockedXDLN_PerWarp>>;
using Kernel =
ck_tile::F16xMXF4FlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue>;
ck_tile::MXFlatmmKernel<TilePartitioner, CodegenMXFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);