add moe_flatmm

This commit is contained in:
Feng Shijie
2025-08-06 08:33:33 +00:00
parent 90e910f3a7
commit 6d3cbc7c0e
12 changed files with 3447 additions and 0 deletions

View File

@@ -0,0 +1,11 @@
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS -Wno-nrvo -Wno-unused-variable -Wno-unused-parameter -Wno-unused-local-typedef -Wno-float-equal)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS --save-temps)
target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,461 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "moe_flatmm.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename FlatmmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int N_Warp_Tile = FlatmmConfig::N_Warp_Tile;
constexpr int N_Warp = FlatmmConfig::N_Warp;
constexpr int KPerLane = FlatmmConfig::K_Warp_Tile / (64 / N_Warp_Tile);
ck_tile::HostTensor<T> t_view({n_ / N_Warp_Tile,
N_Warp_Tile,
k_ / (64 * KPerLane / N_Warp_Tile),
64 / N_Warp_Tile,
KPerLane});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
// gemm1
// operand-A = [num_token, d_model]
// operand-B = [num_expert, hidden, d_model]
// operand-C = [num_token, topk, hidden]
// gemm2
// operand-A = [num_token, topk, hidden]
// operand-B = [num_expert, d_model, hidden]
// operand-C = [num_token, d_model]
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
typename ScaleM,
typename ScaleN>
float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ELayout,
FlatmmConfig::NumWaveGroups>;
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
false, // UsePersistentKernel_
FlatmmConfig::NumWaveGroups,
true>; // Preshuffle_
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
static_assert(
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
}
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN>>;
using CodegenFlatmmPipeline =
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using Kernel = ck_tile::
MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
: args.NumTokens,
args.K,
args.stride_A,
is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
const int outputN =
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
else if(args.k_batch > 1)
hipGetErrorString(
hipMemsetAsync(args.e_ptr,
0,
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
#include "run_moe_flatmm_example.inc"
template <template <typename PreType> typename FlatmmConfig>
int run_moe_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string prec_type = arg_parser.get_str("prec");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
if(gemm_kind == "gemm1_gate_up")
{
if(prec_type == "fp8")
{
return run_moe_gemm_example_with_layouts<
ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf8")
{
return run_moe_gemm_example_with_layouts<
ck_tile::bf8_t,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf16")
{
return run_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "fp16")
{
return run_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else if(gemm_kind == "gemm1_gate_only")
{
if(prec_type == "fp8")
{
return run_moe_gemm_example_with_layouts<
ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf8")
{
return run_moe_gemm_example_with_layouts<
ck_tile::bf8_t,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf16")
{
return run_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "fp16")
{
return run_moe_gemm_example_with_layouts<
ck_tile::half_t,
FlatmmConfig<ck_tile::half_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else if(gemm_kind == "gemm2")
{
if(prec_type == "fp8")
{
return run_moe_gemm_example_with_layouts<ck_tile::fp8_t,
FlatmmConfig<ck_tile::fp8_t>,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf8")
{
return run_moe_gemm_example_with_layouts<ck_tile::bf8_t,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "bf16")
{
return run_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else if(prec_type == "fp16")
{
return run_moe_gemm_example_with_layouts<ck_tile::half_t,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else
{
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
"[gemm1_gate_only | gemm1_gate_up | gemm2]");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
try
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return !run_moe_flatmm_example<FlatmmConfig16>(argc, argv);
}
else if(warp_tile == 1)
{
return !run_moe_flatmm_example<FlatmmConfig32>(argc, argv);
}
else if(warp_tile == 2)
{
return !run_moe_flatmm_example<FlatmmConfig16_950>(argc, argv);
}
else
{
return !run_moe_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,202 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
template <typename DataType>
struct FlatmmConfig32
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr bool TiledMMAPermuteN = false; // disable PermuteN when NWarpTile != 16
};
template <typename DataType>
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
{
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
};
// GEMM config with 16x16 warp tile
template <typename DataType>
struct FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(DataType);
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
static constexpr int kBlockPerCu = 1;
static constexpr int N_Repeat =
N_Tile / FlatmmConfig16<DataType>::N_Warp_Tile / FlatmmConfig16<DataType>::N_Warp;
static constexpr bool TiledMMAPermuteN = false; // N_Repeat % 2 == 0;
};
template <typename ADataType>
struct GemmBasicTypeConfig;
template <>
struct GemmBasicTypeConfig<ck_tile::half_t>
{
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf16_t>
{
using ADataType = ck_tile::bf16_t;
using BDataType = ck_tile::bf16_t;
using AccDataType = float;
using CDataType = ck_tile::bf16_t;
};
template <>
struct GemmBasicTypeConfig<ck_tile::fp8_t>
{
using ADataType = ck_tile::fp8_t;
using BDataType = ck_tile::fp8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
// ToDo: Add more bias config to support different categories of GEMM.
};
template <>
struct GemmBasicTypeConfig<ck_tile::bf8_t>
{
using ADataType = ck_tile::bf8_t;
using BDataType = ck_tile::bf8_t;
using AccDataType = float;
using CDataType = ck_tile::half_t;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <typename T>
struct is_8bit_type
: std::bool_constant<std::is_same_v<T, ck_tile::fp8_t> || std::is_same_v<T, ck_tile::bf8_t>>
{
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
.insert("NumTokens", "128", "M dimensions - 128 by default.")
.insert("TopK", "3", "Top K - 3 by default.")
.insert("N", "4096", "N dimensions - 4096 by default.")
.insert("K", "4096", "K dimensions - 4096 by default.")
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Col by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("gemm_kind",
"gemm1_gate_only",
"Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
"gemm1_gate_only by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}

View File

@@ -0,0 +1,344 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ck_tile::MoeFlatmmKind kind,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
typename MoeHostArgs>
float invoke_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
{
float ave_time = moe_gemm<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
kind,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::string op_name{"Moe Gemm"};
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
namespace {
struct LocalSilu
{
template <typename T>
CK_TILE_HOST_DEVICE T operator()(const T& x) const
{
T y;
ck_tile::element_wise::Silu{}(y, x);
return y;
};
};
} // namespace
template <typename PrecType,
typename FlatmmConfig,
ck_tile::MoeFlatmmKind kind,
typename ALayout,
typename BLayout,
typename CLayout>
int run_moe_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
constexpr int ScaleGranularityM = 1;
constexpr int ScaleGranularityN = 1;
const ck_tile::index_t N = arg_parser.get_int("N");
const ck_tile::index_t K = arg_parser.get_int("K");
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
const ck_tile::index_t topk = arg_parser.get_int("TopK");
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
const ck_tile::index_t experts = arg_parser.get_int("experts");
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
ck_tile::index_t sorted_tile_num = 8;
ck_tile::index_t valid_tile_num = sorted_tile_num;
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;
stride_A = ck_tile::get_default_stride(
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
is_row_major(b_layout)
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
auto b_shuffle_host = shuffle_b<FlatmmConfig>(b_k_n_tensor);
std::cout << "moe_flatmm:"
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
<< "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
<< "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
<< "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
<< std::endl;
ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};
a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
c_m_n_dev_buf.SetZero();
c_m_n_tensor.SetZero();
const void* p_a = a_m_k_dev_buf.GetDeviceBuffer();
const void* p_b_origin = b_origin_dev_buf.GetDeviceBuffer();
const void* p_b_shuffle = b_shuffle_dev_buf.GetDeviceBuffer();
void* p_c = c_m_n_dev_buf.GetDeviceBuffer();
// TODO: malloc and init sorted tokens and max tokens buffer
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
ck_tile::HostTensor<AccDataType> expert_weight(
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
ck_tile::HostTensor<AccDataType> per_token_scale(
ck_tile::HostTensorDescriptor({IsInputGemm ? num_tokens : M}, {1}));
ck_tile::HostTensor<AccDataType> per_channel_scale(
ck_tile::HostTensorDescriptor({N * experts}, {1}));
ck_tile::FillMonotonicSeq<AccDataType>{}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_channel_scale);
// for verification only, no need to satify weight normalization
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
ck_tile::DeviceMem sorted_token_ids_dev{sizeof(ck_tile::index_t) *
sorted_token_ids.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_ids_dev{sizeof(ck_tile::index_t) *
expert_ids.get_element_space_size_in_bytes()};
ck_tile::DeviceMem max_token_id_dev{sizeof(ck_tile::index_t) *
max_token_id.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_weight_dev{sizeof(AccDataType) *
expert_weight.get_element_space_size_in_bytes()};
ck_tile::DeviceMem per_token_scale_dev_buf(per_token_scale.get_element_space_size_in_bytes());
ck_tile::DeviceMem per_channel_scale_dev_buf(
per_channel_scale.get_element_space_size_in_bytes());
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
eids[i] = min(eids[i], experts - 1);
expert_ids.mData[i] = eids[i];
}
// int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
int token_per_tile = num_tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = num_tokens;
}
}
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
expert_ids_dev.ToDevice(expert_ids.data());
max_token_id_dev.ToDevice(max_token_id.data());
expert_weight_dev.ToDevice(expert_weight.data());
per_token_scale_dev_buf.ToDevice(per_token_scale.data());
per_channel_scale_dev_buf.ToDevice(per_channel_scale.data());
const ck_tile::index_t* p_sorted_token_ids_dev =
static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
const ck_tile::index_t* p_expert_ids_dev =
static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
const ck_tile::index_t* p_max_token_id_dev =
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
const AccDataType* p_sorted_expert_weight_dev =
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
using MoeFlatmmArgs =
ck_tile::MoeFlatmmHostArgs<ck_tile::FlatmmScalePointer<1>, ck_tile::FlatmmScalePointer<1>>;
auto per_token_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM>{
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer())};
auto per_channel_scale_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN>{
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer())};
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
p_sorted_expert_weight_dev,
p_expert_ids_dev,
p_max_token_id_dev,
p_a,
p_b_shuffle,
p_c,
num_tokens,
experts,
topk,
1, // k_batch
M,
N,
K,
stride_A,
stride_B,
stride_C,
per_token_scale_dev_ptr,
per_channel_scale_dev_ptr};
invoke_moe_gemm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
kind>(warmup, repeat, gemm_desc);
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
bool pass{true};
if(arg_parser.get_int("validate"))
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
outputN,
stride_C,
is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
c_m_n_ref_buf->SetZero();
ck_tile::reference_moe_gemm_gpu<
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
static_cast<int>(kind),
std::conditional_t<IsInputGemm, LocalSilu, ck_tile::identity>>(
p_sorted_token_ids_dev,
p_expert_ids_dev,
p_max_token_id_dev,
static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b_origin),
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
p_sorted_expert_weight_dev,
num_tokens,
MPerBlock,
topk,
M,
N,
K,
stride_A,
stride_B,
stride_C,
1,
1,
K,
static_cast<float*>(per_token_scale_dev_buf.GetDeviceBuffer()),
static_cast<float*>(per_channel_scale_dev_buf.GetDeviceBuffer()));
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1 /*kbatch*/, max_accumulated_value);
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
float rtol = 1e-3;
float atol = 1e-3;
pass = ck_tile::check_err(
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}

View File

@@ -20,6 +20,7 @@ add_subdirectory(17_grouped_gemm)
add_subdirectory(18_flatmm)
add_subdirectory(19_gemm_multi_d)
add_subdirectory(20_grouped_convolution)
add_subdirectory(21_moe_flatmm)
add_subdirectory(35_batched_transpose)
add_subdirectory(36_copy)
add_subdirectory(37_transpose)