[CK_TILE] MX Flatmm Split kernel instances (#3207)

* [CK_TILE] MX Flatmm Split kernel instances

* Fix flatmm example compile
This commit is contained in:
Yi DING
2025-11-18 13:46:30 +08:00
committed by GitHub
parent 92498464f6
commit b6720531de
12 changed files with 371 additions and 276 deletions

View File

@@ -14,7 +14,12 @@ if(has_supported_gpu)
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
add_executable(tile_example_grouped_flatmm EXCLUDE_FROM_ALL grouped_flatmm.cpp)
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp) # TODO: 950 only
include(mxgemm/mx_flatmm_instance.cmake)
mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")
add_executable(tile_example_mx_flatmm EXCLUDE_FROM_ALL mxgemm/mx_flatmm.cpp ${EXAMPLE_MX_FLATMM_FILES})
target_include_directories(tile_example_mx_flatmm PRIVATE mxgemm)
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)

View File

@@ -29,7 +29,7 @@ static constexpr inline auto is_row_major(Layout layout_)
}
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
auto flatmm_shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];

View File

@@ -20,211 +20,6 @@ static constexpr inline auto is_row_major(Layout layout_)
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ELayout,
FlatmmConfig::NumWaveGroups>;
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
persistent,
FlatmmConfig::NumWaveGroups,
true>;
using ComputeDataType = ADataType;
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
ComputeDataType,
AccDataType,
CodegenFlatmmShape,
Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using CodegenPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenMXFlatmmPipeline =
ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
ComputeDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false, // FixedVectorSize
1, // VectorSizeC
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using Kernel =
ck_tile::MXFlatmmKernel<TilePartitioner, CodegenMXFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenMXFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
preprocess = clear_gemm_output;
}
ave_time = ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
@@ -269,21 +64,66 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
scale_a,
scale_b};
float ave_time = mx_flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
ScaleA,
ScaleB,
UsePersistentKernel,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
using FlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
CLayout,
FlatmmConfig::NumWaveGroups>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, FlatmmShape, Traits>;
using BaseFlatmmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t k_split = (K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(k_split);
const bool has_hot_loop = BaseFlatmmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseFlatmmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time = BaseFlatmmPipeline::template TailHandler<true>(
[&](auto has_hot_loop_, auto tail_num_) {
constexpr auto has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_num_v = tail_num_.value;
auto invoke_splitk_path = [&](auto split_k_) {
return mx_flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
ScaleA,
ScaleB,
UsePersistentKernel,
CDEElementWise,
split_k_.value,
has_hot_loop_v,
tail_num_v>(
args,
ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
};
return (args.k_batch == 1) ? invoke_splitk_path(std::false_type{})
: invoke_splitk_path(std::true_type{});
},
has_hot_loop,
tail_num);
constexpr int APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr int BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
@@ -297,8 +137,8 @@ float invoke_mx_flatmm(ck_tile::DeviceMem& a_dev_buf,
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run MXFP4_Flatmm kernel " //
<< " M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A
<< " StrideB =" << stride_B << " StrideC =" << stride_C << " : " << ave_time
<< " M = " << M << " N = " << N << " K = " << K << " StrideA = " << stride_A
<< " StrideB = " << stride_B << " StrideC = " << stride_C << " : " << ave_time
<< " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
@@ -441,21 +281,13 @@ int run_mx_flatmm_example(int argc, char* argv[])
if(mx_prec == "fp4xfp4")
{
if(persistent_opt == 0)
{
run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
FlatmmConfig,
false>(argc, argv, Row{}, Col{}, Row{});
}
return run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
FlatmmConfig,
false>(argc, argv, Row{}, Col{}, Row{});
else
{
run_mx_flatmm_with_layouts<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
ck_tile::fp16_t,
FlatmmConfig,
true>(argc, argv, Row{}, Col{}, Row{});
}
throw std::runtime_error("Only non-persistent kernels are supported currently!");
}
else if(mx_prec == "fp6xfp6")
{
@@ -487,7 +319,7 @@ int main(int argc, char* argv[])
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return !run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
return run_mx_flatmm_example<MXfp4_FlatmmConfig16>(argc, argv);
}
else if(warp_tile == 1)
{

View File

@@ -0,0 +1,27 @@
# Generates one translation unit per (persistent, split-k, hot-loop, tail-number)
# combination by instantiating mxgemm/mx_flatmm_instance.cpp.in via configure_file.
# FILE_LIST: name of the caller's variable that receives the generated file paths.
function(mx_flatmm_instance_generate FILE_LIST)
# Substitution values consumed by the @VAR@ placeholders in the .cpp.in template.
set(FLATMM_CONFIG MXfp4_FlatmmConfig16)
set(A_DATA_TYPE FP4)
set(B_DATA_TYPE FP4)
set(C_DATA_TYPE FP16)
set(A_LAYOUT ROW)
set(B_LAYOUT COL)
set(C_LAYOUT ROW)
# foreach(PERSISTENT false true)
# TODO: Persistent kernels are disabled due to compilation failures with some LLVM versions.
foreach(PERSISTENT false)
foreach(SPLIT_K false true)
foreach(HAS_HOT_LOOP false true)
foreach(TAIL_NUMBER ODD EVEN)
# One generated .cpp per combination; the loop variables above double as the
# @PERSISTENT@/@SPLIT_K@/@HAS_HOT_LOOP@/@TAIL_NUMBER@ template substitutions,
# so their names must not change.
set(KERNEL_FILE mxgemm/mx_flatmm_instance_${PERSISTENT}_${SPLIT_K}_${HAS_HOT_LOOP}_${TAIL_NUMBER}.cpp)
configure_file(
${CMAKE_CURRENT_SOURCE_DIR}/mxgemm/mx_flatmm_instance.cpp.in
${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE}
@ONLY)
list(APPEND ${FILE_LIST} ${CMAKE_CURRENT_BINARY_DIR}/${KERNEL_FILE})
endforeach()
endforeach()
endforeach()
endforeach()
# Propagate the accumulated file list to the caller's scope.
set(${FILE_LIST} ${${FILE_LIST}} PARENT_SCOPE)
endfunction()

View File

@@ -0,0 +1,53 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "mx_flatmm_instance.hpp"
// Template for CMake configure_file (@ONLY mode): each @VAR@ token below is
// replaced by mx_flatmm_instance_generate() to produce one explicit
// instantiation of mx_flatmm_calc per kernel-variant combination.
// clang-format off
#define FLATMM_CONFIG @FLATMM_CONFIG@
#define A_DATA_TYPE @A_DATA_TYPE@
#define B_DATA_TYPE @B_DATA_TYPE@
#define C_DATA_TYPE @C_DATA_TYPE@
#define A_LAYOUT @A_LAYOUT@
#define B_LAYOUT @B_LAYOUT@
#define C_LAYOUT @C_LAYOUT@
#define PERSISTENT @PERSISTENT@
#define SPLIT_K @SPLIT_K@
#define HAS_HOT_LOOP @HAS_HOT_LOOP@
#define TAIL_NUMBER @TAIL_NUMBER@
// clang-format on
// Aliases matching the token values CMake substitutes into the macros above.
using FP4 = ck_tile::pk_fp4_t;
using FP16 = ck_tile::fp16_t;
using BF16 = ck_tile::bf16_t;
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
// Scale block granularities fed to FlatmmScalePointer (one scale per 32
// elements along K; per-element along M/N).
inline constexpr int ScaleGranularityM = 1;
inline constexpr int ScaleGranularityN = 1;
inline constexpr int ScaleGranularityK = 32;
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>;
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>;
// Explicit instantiation of the single kernel variant selected by the macros
// above; the matching declaration lives with mx_flatmm_calc's prototype.
template float mx_flatmm_calc<FLATMM_CONFIG,
A_DATA_TYPE,
B_DATA_TYPE,
/*DsDatatype*/ ck_tile::tuple<>,
/*AccDataType*/ float,
C_DATA_TYPE,
A_LAYOUT,
B_LAYOUT,
/*DsLayout*/ ck_tile::tuple<>,
C_LAYOUT,
ScaleM,
ScaleN,
PERSISTENT,
/*CDEElementWise*/ ck_tile::element_wise::PassThrough,
SPLIT_K,
HAS_HOT_LOOP,
TAIL_NUMBER>(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -0,0 +1,172 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <type_traits>
#include "ck_tile/host.hpp"
#include "mx_flatmm.hpp"
// Compile-time predicate: true iff Layout (cv/ref-stripped) is RowMajor.
template <typename Layout>
using is_row_major_t = ck_tile::bool_constant<
std::is_same_v<ck_tile::remove_cvref_t<Layout>, ck_tile::tensor_layout::gemm::RowMajor>>;
// Builds, validates, and launches one MX flatmm kernel variant that is fully
// specialized at compile time by FlatmmConfig plus the persistent / Splitk /
// HasHotLoop / TailNum flags, then returns the value reported by
// ck_tile::launch_kernel_time_mask (average time in ms — TODO confirm units).
//
// Template parameters (beyond the data/layout types):
//   persistent - selects the persistent-kernel code path via TileGemmUniversalTraits.
//   Splitk     - when true the epilogue writes with atomic_add so k_batch
//                partial results accumulate; when false it plainly stores (set).
//   HasHotLoop / TailNum - pipeline loop-structure specialization baked into
//                MXFlatmmPipelineProblem.
//
// args - host-side problem description (sizes, strides, pointers, k_batch, scales).
// s    - stream/timing configuration (log level, cache flush, rotation count).
// Throws std::runtime_error if Kernel::IsSupportedArgument rejects the args.
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
// Block/warp/warp-tile shape taken verbatim from the config.
using FlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
// Padding, double-LDS-buffering, layouts, transpose, sparsity, persistence.
using MXGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
persistent,
FlatmmConfig::NumWaveGroups,
true>;
using ComputeDataType = ADataType;
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
constexpr auto scheduler = FlatmmConfig::Scheduler;
// Split-K accumulates partial C tiles with atomic adds; single-K stores directly.
constexpr auto memory_operation =
Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using MXPipelineProblem = ck_tile::MXFlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
FlatmmShape,
MXGemmTraits,
scheduler,
HasHotLoop,
TailNum>;
using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1<MXPipelineProblem>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<FlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
// C-shuffle epilogue specialized with the memory operation chosen above.
using GemmEpilogue =
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<ComputeDataType,
ComputeDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
CLayout,
CDEElementWise,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
MXPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false, // FixedVectorSize
1, // VectorSizeC
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using Kernel = ck_tile::MXFlatmmKernel<TilePartitioner, MXFlatmmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
// Reject unsupported problem shapes up front instead of launching.
if(!Kernel::IsSupportedArgument(kargs))
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << FlatmmShape::GetName() << "\n"
<< "Shape: " << FlatmmShape::GetName() << "\n"
<< "problem: " << MXPipelineProblem::GetName() << "\n"
<< "pipeline: " << MXFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
// Declare rotating_mem_ptr here so it stays in scope until it is needed
std::unique_ptr<ck_tile::RotatingMemWrapper<ADataType, BDataType>> rotating_mem_ptr;
std::function<void()> preprocess;
// With k_batch > 1 the output buffer is accumulated into (atomic_add), so it
// must be zeroed before each timed launch.
// NOTE(review): the hipMemsetAsync return code is stringified and discarded.
auto clear_gemm_output = [&]() {
if(args.k_batch > 1)
hipGetErrorString(
hipMemsetAsync(args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
// PackedSize accounts for sub-byte element packing (e.g. two fp4 per byte).
constexpr ck_tile::index_t APackedSize = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr ck_tile::index_t BPackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major_t<ALayout>{}));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major_t<BLayout>{}));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
// Rotate A/B between copies and flush the icache so each timed iteration
// runs cold; the wrapper must outlive the kernel launches below.
rotating_mem_ptr = std::make_unique<ck_tile::RotatingMemWrapper<ADataType, BDataType>>(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem_ptr->Print();
preprocess = [&]() {
ck_tile::flush_icache();
rotating_mem_ptr->Next();
clear_gemm_output();
};
}
else
{
preprocess = clear_gemm_output;
}
// Time the kernel; preprocess runs before each timed iteration.
return ck_tile::launch_kernel_time_mask(
s,
preprocess,
ck_tile::make_kernel<FlatmmConfig::kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -38,3 +38,23 @@ struct MXfp4_FlatmmConfig16
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
// Declaration of the MX flatmm kernel launcher; definitions are explicitly
// instantiated in the CMake-generated mx_flatmm_instance_*.cpp translation
// units (presumably via mx_flatmm_instance.cpp.in — TODO confirm).
// Splitk selects atomic-add split-K accumulation; HasHotLoop/TailNum pick the
// pipeline loop specialization. Returns the measured kernel time.
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ScaleM,
typename ScaleN,
bool persistent,
typename CDEElementWise,
bool Splitk,
bool HasHotLoop,
ck_tile::TailNumber TailNum>
float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s);

View File

@@ -163,5 +163,5 @@ int run_mx_flatmm_with_layouts(int argc,
std::cout << "The GPU veification result is: " << (pass ? "correct" : "fail") << std::endl;
}
return pass;
return pass ? 0 : -1;
}

View File

@@ -3,23 +3,6 @@
#pragma once
// mfma_type, 0:32x32, 1:16x16
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = FlatmmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / FlatmmConfig::N_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
k_ / FlatmmConfig::K_Warp_Tile,
divisor,
FlatmmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,

View File

@@ -114,7 +114,7 @@ int run_moe_gemm_example_with_layouts(int argc,
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
auto b_shuffle_host = shuffle_b<FlatmmConfig>(b_k_n_tensor);
auto b_shuffle_host = flatmm_shuffle_b<FlatmmConfig>(b_k_n_tensor);
std::cout << "moe_flatmm:" //
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens