add fp16xf4 moe

This commit is contained in:
Feng Shijie
2025-08-18 17:28:11 +00:00
parent 599e1f5b32
commit be55c0f9cb
10 changed files with 1345 additions and 214 deletions

View File

@@ -1,4 +1,6 @@
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
@@ -9,3 +11,4 @@ endif()
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS --save-temps)
target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,534 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include <memory>
#include "a16w4_moe_flatmm.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// gemm1
// operand-A = [num_token, d_model]
// operand-B = [num_expert, hidden, d_model]
// operand-C = [num_token, topk, hidden]
// gemm2
// operand-A = [num_token, topk, hidden]
// operand-B = [num_expert, d_model, hidden]
// operand-C = [num_token, d_model]
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
typename ScaleM,
typename ScaleN>
float a16w4_moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
const ck_tile::stream_config& s)
{
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ELayout,
FlatmmConfig::NumWaveGroups>;
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
false, // UsePersistentKernel_
FlatmmConfig::NumWaveGroups,
true>; // Preshuffle_
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
static_assert(
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
}
using ComputeDataType = ADataType;
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
ComputeDataType,
AccDataType,
CodegenFlatmmShape,
Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem =
std::conditional_t<MXFP4_Pipeline,
ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>,
ck_tile::FlatmmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>>;
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
ComputeDataType,
DsDatatype,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
FlatmmConfig::M_Warp,
FlatmmConfig::N_Warp,
FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using CodegenFlatmmPipeline = std::conditional_t<
MXFP4_Pipeline,
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>,
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>>;
using Kernel = ck_tile::
MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(kargs);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
? 2
: 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
? 2
: 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
: args.NumTokens,
args.K,
args.stride_A,
is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
const int outputN =
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
else if(args.k_batch > 1)
hipGetErrorString(
hipMemsetAsync(args.e_ptr,
0,
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
template <class FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, class IterSrc, class IterDst>
void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N, int K)
{
int KPack = 16;
int NLane = FlatmmConfig::N_Warp_Tile;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
int up_stride = N / 2 / NLane;
for(int eid = 0; eid < experts_cnt; ++eid)
{
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
// interleave gate and up part with granularity is 16.
int n0_interleave = n >= N / 2 ? (n0 - up_stride) * 2 + 1 : // up part
n0 * 2; // gate part
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
k0 * KPack * NLane * KLane + k1 * KPack * NLane + n1 * KPack +
k2;
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
}
}
}
}
else
{
for(int eid = 0; eid < experts_cnt; ++eid)
{
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
k0 * KPack * NLane * KLane + k1 * KPack * NLane + n1 * KPack +
k2;
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
}
}
}
}
}
template <typename FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, typename T>
auto shuffle_mxfp4_scale(const ck_tile::HostTensor<T>& scale, int experts_cnt)
{
assert(scale.get_lengths().size() == 2);
int n_ = scale.get_lengths()[1];
int k_ = scale.get_lengths()[0];
int k_per_expert = k_ / experts_cnt;
constexpr int K_Pack = 2; // fixed for mxfp4
constexpr int N_Pack = 2; // fixed for mxfp4
constexpr int GranularityK = 32; // fixed for mxfp4
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
{
ck_tile::HostTensor<T> shfl_scale({
experts_cnt,
k_per_expert / K_Pack / K_Lane,
K_Pack,
K_Lane,
N_Pack, // N_Pack = 2 is composed of Gate + Up.
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
FlatmmConfig::N_Warp_Tile,
});
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
return ck_tile::reference_permute(shfl_scale, {0, 5, 1, 3, 6, 2, 4});
}
else
{
ck_tile::HostTensor<T> shfl_scale({
experts_cnt,
k_per_expert / K_Pack / K_Lane,
K_Pack,
K_Lane,
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
N_Pack,
FlatmmConfig::N_Warp_Tile,
});
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
return ck_tile::reference_permute(shfl_scale, {0, 4, 1, 3, 6, 2, 5});
}
}
#include "run_a16w4_moe_flatmm_example.inc"
template <typename FlatmmConfig>
int run_a16w4_moe_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string mixed_prec = arg_parser.get_str("mixed_prec");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
if(gemm_kind == "gemm1_gate_up")
{
if(mixed_prec == "fp16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
}
}
else if(gemm_kind == "gemm1_gate_only")
{
if(mixed_prec == "fp16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_gate_only!");
}
}
else if(gemm_kind == "gemm2")
{
if(mixed_prec == "fp16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::half_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm2!");
}
}
else
{
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
"[gemm1_gate_only | gemm1_gate_up | gemm2]");
}
}
else
{
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
}
return -1;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
try
{
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
}
// else if(warp_tile == 1)
// {
// return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
// }
}
catch(const std::runtime_error& e)
{
std::cerr << "Runtime error: " << e.what() << '\n';
return EXIT_FAILURE;
}
}

View File

@@ -0,0 +1,86 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <tuple>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/moe_flatmm.hpp"
// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr int kBlockPerCu = 1;
static constexpr int N_Repeat =
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
static constexpr bool TiledMMAPermuteN = false;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
.insert("NumTokens", "128", "M dimensions - 128 by default.")
.insert("TopK", "3", "Top K - 3 by default.")
.insert("N", "4096", "N dimensions - 4096 by default.")
.insert("K", "4096", "K dimensions - 4096 by default.")
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Col by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("gemm_kind",
"gemm1_gate_only",
"Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
"gemm1_gate_only by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("mixed_prec",
"bf16xfp4",
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
.insert("init", "0", "0:random, 1:constant(1)")
.insert(
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}

View File

@@ -0,0 +1,365 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
ck_tile::MoeFlatmmKind kind,
typename CDEElementWise = ck_tile::element_wise::PassThrough,
typename MoeHostArgs>
float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
{
float ave_time = a16w4_moe_gemm<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
ELayout,
kind,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::string op_name{"Moe Gemm"};
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
sizeof(BDataType) * args.N * args.K / PackedSize +
sizeof(CDataType) * args.M * args.N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
<< gb_per_sec << " GB/s, " << op_name << std::endl;
return ave_time;
}
namespace {
struct LocalSilu
{
template <typename T>
CK_TILE_HOST_DEVICE T operator()(const T& x) const
{
T y;
ck_tile::element_wise::Silu{}(y, x);
return y;
};
};
} // namespace
template <typename PrecActType,
typename PrecWeightType,
typename FlatmmConfig,
ck_tile::MoeFlatmmKind kind,
typename ALayout,
typename BLayout,
typename CLayout>
int run_a16w4_moe_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
{
return -1;
};
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using CDataType = PrecActType;
using AccDataType = float;
using ScaleType = ck_tile::e8m0_t;
constexpr int ScaleGranularityN = 1;
constexpr int ScaleGranularityK = 32;
const ck_tile::index_t N = arg_parser.get_int("N");
const ck_tile::index_t K = arg_parser.get_int("K");
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
ck_tile::index_t init_method = arg_parser.get_int("init");
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
const ck_tile::index_t topk = arg_parser.get_int("TopK");
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
const ck_tile::index_t experts = arg_parser.get_int("experts");
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
ck_tile::index_t sorted_tile_num = num_tokens * topk / MPerBlock;
ck_tile::index_t valid_tile_num = sorted_tile_num;
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;
stride_A = ck_tile::get_default_stride(
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
is_row_major(b_layout)
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
ck_tile::FillUniformDistribution<ScaleType>{0.f, 1.f}(scale_b);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f}(b_k_n_tensor);
ck_tile::FillUniformDistribution<ScaleType>{1.0f, 1.0f}(scale_b);
}
ck_tile::HostTensor<BDataType> b_shuffle_host(
ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
shuffle_mxfp4_weight<FlatmmConfig, kind>(
b_k_n_tensor.begin(), b_shuffle_host.begin(), experts, N, K);
ck_tile::HostTensor<ScaleType> scale_b_shuffle =
shuffle_mxfp4_scale<FlatmmConfig, kind>(scale_b, experts);
ck_tile::DeviceMem scale_b_shuffle_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
std::cout << "moe_flatmm:"
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
<< "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
<< "\n problem_n: " << N << "\n problem_k: " << K
<< "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
<< "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
<< std::endl;
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
ck_tile::HostTensor<AccDataType> expert_weight(
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
if(init_method == 0)
{
// for verification only, no need to satify weight normalization
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
}
else
{
ck_tile::FillUniformDistribution<AccDataType>{1.0f, 1.0f}(expert_weight);
}
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
// int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
}
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
// int token_per_tile = num_tokens * topk / valid_tile_num;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = num_tokens;
}
}
ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};
a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
c_m_n_dev_buf.SetZero();
c_m_n_tensor.SetZero();
ck_tile::DeviceMem sorted_token_ids_dev{sizeof(ck_tile::index_t) *
sorted_token_ids.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_ids_dev{sizeof(ck_tile::index_t) *
expert_ids.get_element_space_size_in_bytes()};
ck_tile::DeviceMem max_token_id_dev{sizeof(ck_tile::index_t) *
max_token_id.get_element_space_size_in_bytes()};
ck_tile::DeviceMem expert_weight_dev{sizeof(AccDataType) *
expert_weight.get_element_space_size_in_bytes()};
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
expert_ids_dev.ToDevice(expert_ids.data());
max_token_id_dev.ToDevice(max_token_id.data());
expert_weight_dev.ToDevice(expert_weight.data());
scale_b_shuffle_dev_buf.ToDevice(scale_b_shuffle.data());
const ck_tile::index_t* p_sorted_token_ids_dev =
static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
const ck_tile::index_t* p_expert_ids_dev =
static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
const ck_tile::index_t* p_max_token_id_dev =
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
const AccDataType* p_sorted_expert_weight_dev =
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
auto scale_b_shuffle_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
ck_tile::FlatmmScalePointer<-1>,
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>>;
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
p_sorted_expert_weight_dev,
p_expert_ids_dev,
p_max_token_id_dev,
a_m_k_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
c_m_n_dev_buf.GetDeviceBuffer(),
num_tokens,
experts,
topk,
1, // k_batch
M,
N,
K,
stride_A,
stride_B,
stride_C,
nullptr,
scale_b_shuffle_dev_ptr};
invoke_a16w4_moe_gemm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout,
kind>(warmup, repeat, gemm_desc);
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
bool pass{true};
if(arg_parser.get_int("validate"))
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
outputN,
stride_C,
is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::HostTensor<AccDataType> scale_A(
ck_tile::HostTensorDescriptor({1, K / ScaleGranularityK}, {1, 1}));
// scaleA = 1 has no effect on the result
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
scale_A_dev_buf.ToDevice(scale_A.data());
// convert scale_b from e8m0 to float
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
c_m_n_ref_buf->SetZero();
ck_tile::reference_moe_gemm_gpu<
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout,
static_cast<int>(kind),
std::conditional_t<IsInputGemm, LocalSilu, ck_tile::identity>>(
p_sorted_token_ids_dev,
p_expert_ids_dev,
p_max_token_id_dev,
static_cast<const ADataType*>(a_m_k_dev_buf.GetDeviceBuffer()),
static_cast<const BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
p_sorted_expert_weight_dev,
num_tokens,
MPerBlock,
topk,
M,
N,
K,
stride_A,
stride_B,
stride_C,
M,
1,
ScaleGranularityK,
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
pass = ck_tile::check_err(
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
return pass;
}

View File

@@ -165,6 +165,10 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
has_hot_loop_v,
tail_number_v>;
constexpr int BlockedXDLN_PerWarp = moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up
? 2
: 1; // determined by scale shuffle pattern
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -187,7 +191,8 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
FlatmmConfig::NumWaveGroups,
false,
1,
FlatmmConfig::TiledMMAPermuteN>>;
FlatmmConfig::TiledMMAPermuteN,
BlockedXDLN_PerWarp>>;
using CodegenFlatmmPipeline =
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;

View File

@@ -170,7 +170,6 @@ int run_moe_gemm_example_with_layouts(int argc,
ck_tile::HostTensor<AccDataType> per_channel_scale(
ck_tile::HostTensorDescriptor({N * experts}, {1}));
ck_tile::FillMonotonicSeq<AccDataType>{}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_token_scale);
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_channel_scale);
@@ -195,12 +194,12 @@ int run_moe_gemm_example_with_layouts(int argc,
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / (valid_tile_num / experts);
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
}
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
// int token_per_tile = num_tokens * topk / valid_tile_num;
int tokenid = 0;
int tokenid = 0;
// sorted_token_ids.mData[0] = 0;
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
{
@@ -329,8 +328,8 @@ int run_moe_gemm_example_with_layouts(int argc,
K, 1 /*kbatch*/, max_accumulated_value);
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
float rtol = 1e-3;
float atol = 1e-3;
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
pass = ck_tile::check_err(
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);

View File

@@ -82,11 +82,13 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
AccDataType acc_temp = 0.0;
AccDataType acc_up_temp = 0.0;
float scale_A = 0;
float scale_B = 0;
float scale_B_up = 0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
float scale_A = 0;
float scale_B = 0;
float scale_B_up = 0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
for(int k = 0; k < K; ++k)
{
@@ -101,12 +103,13 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
// update scale factors
scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
(k / scale_granularity_k) * scale_A_stride];
scale_B = scale_B_ptr[((expert_id * N + col) / scale_granularity_n) +
(k / scale_granularity_k) * scale_B_stride];
scale_B =
scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
if constexpr(MoeGemmKind == 1)
scale_B_up =
scale_B_ptr[((expert_id * N + col + problem_N) / scale_granularity_n) +
(k / scale_granularity_k) * scale_B_stride];
scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
(col + problem_N) / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
@@ -138,6 +141,14 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
@@ -159,6 +170,22 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
v_b_up = fp32_val_up.lo;
}
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
if constexpr(MoeGemmKind == 1)
{
const fp32x2_t fp32_val_up = pk_fp4_to_fp32x2(B[b_index_up / packed_size_b]);
if(k % 2 == 1)
v_b_up = fp32_val_up.hi;
else
v_b_up = fp32_val_up.lo;
}
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);

View File

@@ -370,6 +370,14 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
const auto& d_block_window = gemm_tile_windows.at(I2);
const auto& scale_block_window = gemm_tile_windows.at(I4);
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
"ScaleM and ScaleN should have the same GranularityK");
constexpr bool DoEpiScale =
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
auto a_block_window_with_distr =
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
@@ -383,26 +391,21 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0))
if constexpr(DoEpiScale)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
EpiloguePipeline{}(c_block_window,
c_block_tile,
d_block_window,
smem_ptr_ping,
kargs.scale_m_ptr + block_idx_m,
kargs.scale_n_ptr + block_idx_n);
}
else if(UseDefaultScheduler || (get_warp_id() == 0))
{
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
}
}

View File

@@ -80,7 +80,8 @@ enum class MoeFlatmmKind
template <typename TilePartitioner_,
typename FlatmmPipeline_,
typename EpiloguePipeline_,
MoeFlatmmKind kind>
MoeFlatmmKind kind,
typename FusedActivation = element_wise::Silu>
struct MoeFlatmmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -101,7 +102,8 @@ struct MoeFlatmmKernel
// Below type is actually accumulation data type - the output of block GEMM.
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
using AccDataType = float;
using AccDataType = float;
using ActivationOp = FusedActivation;
static constexpr index_t NumDTensor = DsDataType::size();
@@ -114,6 +116,7 @@ struct MoeFlatmmKernel
"The size of DsLayout and DsDataType should be the same");
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
@@ -128,6 +131,17 @@ struct MoeFlatmmKernel
static constexpr index_t kNPerIteration = NPerXdl * NWave;
static constexpr index_t kNRepeat = kNPerBlock / kNPerIteration;
static constexpr int OutputNPerBlock =
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
// MXF4_Pipeline only has the of scale B and granularityK is 32
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static constexpr int MXFP4N_Pack = 2;
static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
struct MoeFlatmmKernelArgs
{
@@ -405,10 +419,10 @@ struct MoeFlatmmKernel
const BDataType* b_flat_ptr,
EDataType* e_ptr,
const AccDataType* exp_weight_ptr,
const int expert_id,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
// static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
@@ -432,9 +446,9 @@ struct MoeFlatmmKernel
}
}();
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
@@ -451,7 +465,7 @@ struct MoeFlatmmKernel
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kargs.N / 2 : kargs.N),
IsGateUp ? kargs.N / 2 : kargs.N),
make_tuple(kargs.stride_C, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
@@ -461,14 +475,30 @@ struct MoeFlatmmKernel
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
e_ptr,
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kargs.N / 2 : kargs.N),
IsGateUp ? kargs.N / 2 : kargs.N),
make_tuple(1, kargs.stride_C),
number<1>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view);
auto scale_n = kargs.scale_n;
constexpr int GranularityK = decltype(scale_n)::GranularityK;
index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
make_tuple(FlatScaleN, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
}
template <typename TensorView>
@@ -492,14 +522,9 @@ struct MoeFlatmmKernel
}
}();
const auto& b_flat_tensor_view = views.at(I1);
// TODO vector write in for C in ColMajor
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
constexpr int OutputNPerBlock = kind == MoeFlatmmKind::kFFN_gemm1_gate_up
? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock;
const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
return pad_tensor_view(
@@ -516,12 +541,13 @@ struct MoeFlatmmKernel
}
}();
return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view);
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
}
template <typename PadView>
CK_TILE_DEVICE static auto
MakeGemmTileWindows(const PadView& views, [[maybe_unused]] const index_t i_m, const index_t i_n)
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
[[maybe_unused]] const index_t coord_m,
const index_t coord_n)
{
const auto& a_pad_view = views.at(number<0>{});
const auto& b_flat_pad_view = views.at(number<1>{});
@@ -533,7 +559,7 @@ struct MoeFlatmmKernel
return make_tile_window(a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_m, 0}); // NOTE!
{coord_m, 0}); // NOTE!
}
else
{
@@ -544,25 +570,33 @@ struct MoeFlatmmKernel
}
}();
const int problem_N_offset = kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? i_n / 2 : i_n;
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
const auto& b_flat_block_window = make_tile_window(
b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(problem_N_offset / BlockGemmShape::WarpTile::at(I1)), 0});
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp>{}),
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
(isNonInterleaveGateUp ? 1 : 2)),
0});
constexpr int OutputNPerBlock = kind == MoeFlatmmKind::kFFN_gemm1_gate_up
? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock;
const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
auto c_block_window = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<OutputNPerBlock>{}),
{0, // offset_m is included when construct C-scatter-window offsets
problem_N_offset});
output_N_offset});
return make_tuple(a_block_window, b_flat_block_window, c_block_window);
constexpr int GranularityK = 32;
auto scale_block_window = make_tile_window(
views.at(I3),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / GranularityK>{}),
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
}
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
@@ -614,16 +648,16 @@ struct MoeFlatmmKernel
const ADataType* a_ptr =
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
splitk_batch_offset.b_k_split_offset +
expert_stride * expert_id;
const BDataType* b_flat_ptr =
static_cast<const BDataType*>(kargs.b_ptr) +
(splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
const AccDataType* exp_weight_ptr =
static_cast<const AccDataType*>(kargs.p_sorted_expert_weights);
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, kargs, splitk_batch_offset);
a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, coord_m, coord_n);
@@ -631,26 +665,43 @@ struct MoeFlatmmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(number<0>{});
const auto& b_block_window = gemm_tile_windows.at(number<1>{});
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& scale_block_window = gemm_tile_windows.at(I3);
auto a_gather_block_tile =
ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
a_block_window.get_window_lengths(),
a_block_window.get_window_origin(),
FlatmmPipeline::GetADramTileDistribution(),
a_dram_dist,
a_offsets); // K DRAM tile window for
auto c_block_tile = FlatmmPipeline{}(a_gather_block_tile,
b_block_window,
number<kind == MoeFlatmmKind::kFFN_gemm1_gate_up>{},
num_loop,
smem_ptr_ping,
smem_ptr_pong);
using AccTile = decltype(c_block_tile);
auto c_block_tile = [&] {
if constexpr(MXFP4_Pipeline)
{
// MXFP4_Pipeline uses gate-up interleave 16 layout for weight
// so don't need extra processing
return FlatmmPipeline{}(a_gather_block_tile,
b_block_window,
scale_block_window, // weight scale with granularityK = 32
num_loop,
smem_ptr_ping,
smem_ptr_pong);
}
else
{
return FlatmmPipeline{}(a_gather_block_tile,
b_block_window,
number<IsGateUp>{},
num_loop,
smem_ptr_ping,
smem_ptr_pong);
}
}();
using AccTile = decltype(c_block_tile);
// Run EpiloguePipeline Pipeline
auto& c_block_window = gemm_tile_windows.at(number<2>{});
using ActivationOp = element_wise::Silu;
{
using EpiProblem = typename EpiloguePipeline::Problem;
@@ -666,26 +717,34 @@ struct MoeFlatmmKernel
constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
constexpr auto lds_block_desc =
EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
constexpr index_t OutputNumNXdlPerWavePerShuffle =
IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
constexpr index_t LDS_NPerIterationShuffle =
IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
constexpr auto lds_block_desc = make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
make_tuple(number<LDS_NPerIterationShuffle>{}, number<1>{}));
// EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
auto in_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
{0, 0});
auto out_lds_window = make_tile_window(
o_lds_block,
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
{0, 0});
using SFC = space_filling_curve<
sequence<kMPerBlock,
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kNPerBlock / 2 : kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
sequence<0, 1>,
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
constexpr index_t num_access = SFC::get_num_of_access();
@@ -696,7 +755,7 @@ struct MoeFlatmmKernel
using TileEncodingPattern = TileDistributionEncodingPattern2D<
kBlockSize,
MPerIterationShuffle,
NPerIterationShuffle,
LDS_NPerIterationShuffle,
kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
tile_distribution_pattern::thread_raked,
EpiProblem::kNumWaveGroups>;
@@ -704,8 +763,24 @@ struct MoeFlatmmKernel
constexpr auto dram_tile_distribution =
TileEncodingPattern::Make2DStaticTileDistribution();
constexpr auto LdsTileDistr =
make_static_tile_distribution(EpiloguePipeline::MakeLdsDistributionEncode());
constexpr auto LdsTileDistr = [&] {
if constexpr(IsGateUp)
return make_static_tile_distribution(
detail::make_embed_tile_distribution_encoding(
tile_distribution_encoding<
sequence<>,
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
// merge two contiguous N
sequence<OutputNumNXdlPerWavePerShuffle, NWave>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{},
typename CWarpDstr::DstrEncode{}));
else
return make_static_tile_distribution(
EpiloguePipeline::MakeLdsDistributionEncode());
}();
using LDSTileTensor =
decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
@@ -719,8 +794,8 @@ struct MoeFlatmmKernel
constexpr int kM1 = (64 / NPerXdl); // Thr
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
constexpr int ActVectorSize =
c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle;
constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
OutputNumNXdlPerWavePerShuffle;
const index_t iMWarp = get_warp_id() / NWave;
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
@@ -737,32 +812,36 @@ struct MoeFlatmmKernel
//===----------------------------------------------------------------------===//
// Load scales and expert weights
//===----------------------------------------------------------------------===//
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
if constexpr(!MXFP4_Pipeline)
{
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
vec_scale_B[i] = kargs.scale_n[expert_id * kargs.N + coord_n / 2 +
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
vec_scale_B[i + NRepeat / 2] =
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
});
if constexpr(IsGateUp)
{
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
vec_scale_B[i * 2] =
kargs.scale_n[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
iNWarp * NPerXdl + iNLane];
vec_scale_B[i * 2 + 1] =
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto i) {
vec_scale_B[i] =
kargs.scale_n[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
iNWarp * NPerXdl + iNLane];
});
}
}
else
{
static_for<0, NRepeat, 1>{}([&](auto i) {
vec_scale_B[i] = kargs.scale_n[expert_id * kargs.N + coord_n +
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
});
}
static_for<0, MRepeat, 1>{}([&](auto i) {
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
index_t M2_offset = m2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
i * MPerXdl * MWave + coord_m;
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
kargs.scale_m[row_to_token_idx(M2_offset)];
if constexpr(!MXFP4_Pipeline)
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
kargs.scale_m[row_to_token_idx(M2_offset)];
if constexpr(!IsInputGemm)
vec_expert_weights[i * kM0 * kM2 + m0 * kM2 + m2] =
expert_weights[M2_offset];
@@ -770,46 +849,54 @@ struct MoeFlatmmKernel
});
});
constexpr int UpAccStride = NRepeat / 2;
//===----------------------------------------------------------------------===//
// Pingpong process start
//===----------------------------------------------------------------------===//
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
if constexpr(IsGateUp)
{
LDSTileTensor gate_tensor, up_tensor;
static_assert((NRepeat / NumNXdlPerWavePerShuffle) % 2 == 0);
// gate and up are interleaved along NRepeat dimension.
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
gate_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths)));
gate_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
up_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle,
0 * NumNXdlPerWavePerShuffle + UpAccStride>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
up_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl + 1>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths)));
});
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
gate_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[n_xdl];
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[n_xdl + UpAccStride];
if constexpr(!MXFP4_Pipeline)
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
gate_tensor
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[2 * n_xdl];
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[2 * n_xdl + 1];
});
});
});
});
});
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
@@ -830,16 +917,18 @@ struct MoeFlatmmKernel
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
if constexpr(!IsInputGemm)
lds_tile[0]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[n_xdl];
if constexpr(!MXFP4_Pipeline)
lds_tile[0]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[n_xdl];
});
});
});
@@ -875,51 +964,62 @@ struct MoeFlatmmKernel
if constexpr(iAccess < num_access - 1)
{
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
if constexpr(IsGateUp)
{
LDSTileTensor gate_tensor, up_tensor;
gate_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
nIter_next * NumNXdlPerWavePerShuffle>{},
c_warp_y_index_zeros),
merge_sequences(
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
up_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
merge_sequences(
sequence<mIter_next * NumMXdlPerWavePerShuffle,
nIter_next * NumNXdlPerWavePerShuffle + UpAccStride>{},
c_warp_y_index_zeros),
merge_sequences(
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
c_warp_y_lengths));
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
gate_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths),
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
nIter_next * NumNXdlPerWavePerShuffle +
2 * n_xdl>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths)));
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) *
c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
gate_tensor
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
n_xdl];
up_tensor
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
n_xdl + UpAccStride];
up_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths),
c_block_tile.get_y_sliced_thread_data(
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
nIter_next * NumNXdlPerWavePerShuffle +
2 * n_xdl + 1>{},
c_warp_y_index_zeros),
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
c_warp_y_lengths)));
});
if constexpr(!MXFP4_Pipeline)
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
gate_tensor.get_thread_buffer()[acc_xdl_offset +
m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
2 * n_xdl];
up_tensor.get_thread_buffer()[acc_xdl_offset +
m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
2 * n_xdl + 1];
});
});
});
});
});
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
gate_tensor.get_thread_buffer().at(idx));
@@ -941,7 +1041,7 @@ struct MoeFlatmmKernel
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
constexpr int acc_xdl_offset =
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) *
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) *
c_warp_y_lengths.product();
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
@@ -951,13 +1051,15 @@ struct MoeFlatmmKernel
m2] *= vec_expert_weights
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
n_xdl];
if constexpr(!MXFP4_Pipeline)
lds_tile[write_stage]
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
m2] *=
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
kM0 * kM2 +
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
n_xdl];
});
});
});

View File

@@ -582,19 +582,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
{
if constexpr(!IsGateUpMode)
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
}
else
{
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
});
});
@@ -637,18 +638,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
{
if constexpr(!IsGateUpMode)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
}
else
{
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
@@ -723,18 +726,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
{
if constexpr(!IsGateUpMode)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
}
else
{
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
@@ -812,18 +817,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
{
if constexpr(!IsGateUpMode)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
}
else
{
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
if constexpr(nIter % 2 == 0)
move_tile_window(
b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
else
move_tile_window(b_flat_dram_windows(nIter)(kIter),
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
kIter * KFlatPerBlockPerIter});
}
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));