mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-05 20:55:59 +00:00
add fp16xf4 moe
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
add_executable(tile_example_moe_flatmm EXCLUDE_FROM_ALL moe_flatmm.cpp)
|
||||
add_executable(tile_example_a16w4_moe_flatmm EXCLUDE_FROM_ALL mixed_prec/a16w4_moe_flatmm.cpp)
|
||||
|
||||
|
||||
set(EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS)
|
||||
|
||||
@@ -9,3 +11,4 @@ endif()
|
||||
|
||||
list(APPEND EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS --save-temps)
|
||||
target_compile_options(tile_example_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})
|
||||
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_MOE_FLATMM_COMPILE_OPTIONS})
|
||||
|
||||
534
example/ck_tile/21_moe_flatmm/mixed_prec/a16w4_moe_flatmm.cpp
Normal file
534
example/ck_tile/21_moe_flatmm/mixed_prec/a16w4_moe_flatmm.cpp
Normal file
@@ -0,0 +1,534 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <memory>
|
||||
|
||||
#include "a16w4_moe_flatmm.hpp"
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/flatmm.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/reference/reference_moe_gemm.hpp"
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// gemm1
|
||||
// operand-A = [num_token, d_model]
|
||||
// operand-B = [num_expert, hidden, d_model]
|
||||
// operand-C = [num_token, topk, hidden]
|
||||
|
||||
// gemm2
|
||||
// operand-A = [num_token, topk, hidden]
|
||||
// operand-B = [num_expert, d_model, hidden]
|
||||
// operand-C = [num_token, d_model]
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind moe_kind = ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename ScaleM,
|
||||
typename ScaleN>
|
||||
float a16w4_moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenFlatmmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
|
||||
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
|
||||
FlatmmConfig::TileParitionerGroupNum,
|
||||
FlatmmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::NumWaveGroups>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
|
||||
FlatmmConfig::kPadN,
|
||||
FlatmmConfig::kPadK,
|
||||
FlatmmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
FlatmmConfig::TransposeC,
|
||||
FlatmmConfig::UseStructuredSparsity,
|
||||
false, // UsePersistentKernel_
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
true>; // Preshuffle_
|
||||
|
||||
constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, ck_tile::pk_fp4_t>;
|
||||
|
||||
if constexpr(!MXFP4_Pipeline && moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
static_assert(
|
||||
FlatmmConfig::N_Tile % (FlatmmConfig::N_Warp * FlatmmConfig::N_Warp_Tile * 2) == 0,
|
||||
"requires NRepeat is multiple of 2 for FFN_gemm1_gate_up");
|
||||
}
|
||||
|
||||
using ComputeDataType = ADataType;
|
||||
static_assert(sizeof(ComputeDataType) >= sizeof(BDataType),
|
||||
"mixed_prec_flatmm requires ADataType is a wider type than BDataType");
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
Traits>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = FlatmmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
std::conditional_t<MXFP4_Pipeline,
|
||||
ck_tile::F16xMXF4FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>,
|
||||
ck_tile::FlatmmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenFlatmmShape,
|
||||
CodegenGemmTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>>;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ComputeDataType,
|
||||
ComputeDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
FlatmmConfig::M_Warp,
|
||||
FlatmmConfig::N_Warp,
|
||||
FlatmmConfig::M_Warp_Tile,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
FlatmmConfig::K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using CodegenFlatmmPipeline = std::conditional_t<
|
||||
MXFP4_Pipeline,
|
||||
ck_tile::F16xMXF4FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>,
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>>;
|
||||
|
||||
using Kernel = ck_tile::
|
||||
MoeFlatmmKernel<TilePartitioner, CodegenFlatmmPipeline, GemmEpilogue, moe_kind>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(kargs);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
|
||||
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
|
||||
? 2
|
||||
: 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<BDataType, ck_tile::pk_fp4_t>
|
||||
? 2
|
||||
: 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2 ? args.NumTokens * args.TopK
|
||||
: args.NumTokens,
|
||||
args.K,
|
||||
args.stride_A,
|
||||
is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N * args.NumExperts, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
const int outputN =
|
||||
moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? args.N / 2 : args.N;
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm2)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.NumTokens * args.N * sizeof(CDataType), s.stream_id_));
|
||||
else if(args.k_batch > 1)
|
||||
hipGetErrorString(
|
||||
hipMemsetAsync(args.e_ptr,
|
||||
0,
|
||||
args.NumTokens * args.TopK * outputN * sizeof(CDataType),
|
||||
s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, FlatmmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <class FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, class IterSrc, class IterDst>
|
||||
void shuffle_mxfp4_weight(const IterSrc src, IterDst dst, int experts_cnt, int N, int K)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = FlatmmConfig::N_Warp_Tile;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
int up_stride = N / 2 / NLane;
|
||||
|
||||
for(int eid = 0; eid < experts_cnt; ++eid)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
// interleave gate and up part with granularity is 16.
|
||||
int n0_interleave = n >= N / 2 ? (n0 - up_stride) * 2 + 1 : // up part
|
||||
n0 * 2; // gate part
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = eid * N * K_pk + n0_interleave * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane + n1 * KPack +
|
||||
k2;
|
||||
|
||||
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int eid = 0; eid < experts_cnt; ++eid)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
for(int k = 0; k < K_pk; ++k)
|
||||
{
|
||||
int n0 = n / NLane;
|
||||
int n1 = n % NLane;
|
||||
|
||||
int k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
int k1 = tempk / KPack;
|
||||
int k2 = tempk % KPack;
|
||||
|
||||
int outputIndex = eid * N * K_pk + n0 * KPack * NLane * KLane * K0 +
|
||||
k0 * KPack * NLane * KLane + k1 * KPack * NLane + n1 * KPack +
|
||||
k2;
|
||||
|
||||
dst[outputIndex] = src[eid * N * K_pk + n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FlatmmConfig, ck_tile::MoeFlatmmKind moe_kind, typename T>
|
||||
auto shuffle_mxfp4_scale(const ck_tile::HostTensor<T>& scale, int experts_cnt)
|
||||
{
|
||||
assert(scale.get_lengths().size() == 2);
|
||||
int n_ = scale.get_lengths()[1];
|
||||
int k_ = scale.get_lengths()[0];
|
||||
|
||||
int k_per_expert = k_ / experts_cnt;
|
||||
|
||||
constexpr int K_Pack = 2; // fixed for mxfp4
|
||||
constexpr int N_Pack = 2; // fixed for mxfp4
|
||||
constexpr int GranularityK = 32; // fixed for mxfp4
|
||||
|
||||
constexpr int K_Lane = 64 / FlatmmConfig::N_Warp_Tile; // 4
|
||||
|
||||
static_assert(FlatmmConfig::N_Warp_Tile == 16, "only support XDL_N == 16");
|
||||
static_assert(FlatmmConfig::N_Repeat % N_Pack == 0);
|
||||
static_assert(FlatmmConfig::K_Tile % (K_Pack * K_Lane * GranularityK) == 0);
|
||||
|
||||
if constexpr(moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
{
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
experts_cnt,
|
||||
k_per_expert / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
N_Pack, // N_Pack = 2 is composed of Gate + Up.
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {0, 5, 1, 3, 6, 2, 4});
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::HostTensor<T> shfl_scale({
|
||||
experts_cnt,
|
||||
k_per_expert / K_Pack / K_Lane,
|
||||
K_Pack,
|
||||
K_Lane,
|
||||
n_ / FlatmmConfig::N_Warp_Tile / N_Pack,
|
||||
N_Pack,
|
||||
FlatmmConfig::N_Warp_Tile,
|
||||
});
|
||||
std::copy(scale.begin(), scale.end(), shfl_scale.begin());
|
||||
return ck_tile::reference_permute(shfl_scale, {0, 4, 1, 3, 6, 2, 5});
|
||||
}
|
||||
}
|
||||
|
||||
#include "run_a16w4_moe_flatmm_example.inc"
|
||||
|
||||
template <typename FlatmmConfig>
|
||||
int run_a16w4_moe_flatmm_example(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
const std::string mixed_prec = arg_parser.get_str("mixed_prec");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
const std::string gemm_kind = arg_parser.get_str("gemm_kind");
|
||||
if(gemm_kind == "gemm1_gate_up")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_up!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm1_gate_only")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<
|
||||
ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_only>(argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm1_gate_only!");
|
||||
}
|
||||
}
|
||||
else if(gemm_kind == "gemm2")
|
||||
{
|
||||
if(mixed_prec == "fp16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::half_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(mixed_prec == "bf16xfp4")
|
||||
{
|
||||
return run_a16w4_moe_gemm_example_with_layouts<ck_tile::bfloat16_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind::kFFN_gemm2>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported precision type for gemm2!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unrecoginized gemm_kind parameter, only accept value "
|
||||
"[gemm1_gate_only | gemm1_gate_up | gemm2]");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!");
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return EXIT_FAILURE;
|
||||
|
||||
try
|
||||
{
|
||||
int warp_tile = arg_parser.get_int("warp_tile");
|
||||
if(warp_tile == 0)
|
||||
{
|
||||
return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16>(argc, argv);
|
||||
}
|
||||
// else if(warp_tile == 1)
|
||||
// {
|
||||
// return !run_a16w4_moe_flatmm_example<A16W4_FlatmmConfig16_950>(argc, argv);
|
||||
// }
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Runtime error: " << e.what() << '\n';
|
||||
return EXIT_FAILURE;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/moe_flatmm.hpp"
|
||||
|
||||
// GEMM config with 16x16 warp tile
|
||||
struct A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 64;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int TileParitionerGroupNum = 8;
|
||||
static constexpr int TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
struct A16W4_FlatmmConfig16_950 : public A16W4_FlatmmConfig16
|
||||
{
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr int N_Repeat =
|
||||
N_Tile / A16W4_FlatmmConfig16::N_Warp_Tile / A16W4_FlatmmConfig16::N_Warp;
|
||||
static constexpr bool TiledMMAPermuteN = false;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("experts", "8", "Num of experts - 8 by default")
|
||||
.insert("NumTokens", "128", "M dimensions - 128 by default.")
|
||||
.insert("TopK", "3", "Top K - 3 by default.")
|
||||
.insert("N", "4096", "N dimensions - 4096 by default.")
|
||||
.insert("K", "4096", "K dimensions - 4096 by default.")
|
||||
.insert("stride_A", "", "Tensor A strides - it is empty by default.")
|
||||
.insert("stride_B", "", "Tensor B strides - it is empty by default.")
|
||||
.insert("stride_C", "", "Tensor C strides - it is empty by default.")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("gemm_kind",
|
||||
"gemm1_gate_only",
|
||||
"Gemm kind in FFN network [gemm1_gate_only | gemm1_gate_up | gemm2] - "
|
||||
"gemm1_gate_only by default.")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("mixed_prec",
|
||||
"bf16xfp4",
|
||||
"data type for activation and weight, support: bf16xfp4, fp16xfp4")
|
||||
.insert("init", "0", "0:random, 1:constant(1)")
|
||||
.insert(
|
||||
"warp_tile", "0", "0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
@@ -0,0 +1,365 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
template <typename FlatmmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDatatype,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough,
|
||||
typename MoeHostArgs>
|
||||
float invoke_a16w4_moe_gemm(int n_warmup, int n_repeat, const MoeHostArgs& args)
|
||||
{
|
||||
float ave_time = a16w4_moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDatatype,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
kind,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
|
||||
std::string op_name{"Moe Gemm"};
|
||||
|
||||
constexpr int PackedSize = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
std::size_t flop = std::size_t(2) * args.M * args.N * args.K;
|
||||
std::size_t num_byte = sizeof(ADataType) * args.M * args.K +
|
||||
sizeof(BDataType) * args.N * args.K / PackedSize +
|
||||
sizeof(CDataType) * args.M * args.N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
struct LocalSilu
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE T operator()(const T& x) const
|
||||
{
|
||||
T y;
|
||||
ck_tile::element_wise::Silu{}(y, x);
|
||||
return y;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename PrecActType,
|
||||
typename PrecWeightType,
|
||||
typename FlatmmConfig,
|
||||
ck_tile::MoeFlatmmKind kind,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
int run_a16w4_moe_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
return -1;
|
||||
};
|
||||
|
||||
using ADataType = PrecActType;
|
||||
using BDataType = PrecWeightType;
|
||||
using CDataType = PrecActType;
|
||||
using AccDataType = float;
|
||||
|
||||
using ScaleType = ck_tile::e8m0_t;
|
||||
|
||||
constexpr int ScaleGranularityN = 1;
|
||||
constexpr int ScaleGranularityK = 32;
|
||||
|
||||
const ck_tile::index_t N = arg_parser.get_int("N");
|
||||
const ck_tile::index_t K = arg_parser.get_int("K");
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_A");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_B");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_C");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
const ck_tile::index_t num_tokens = arg_parser.get_int("NumTokens");
|
||||
const ck_tile::index_t topk = arg_parser.get_int("TopK");
|
||||
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
|
||||
|
||||
ck_tile::index_t sorted_tile_num = num_tokens * topk / MPerBlock;
|
||||
ck_tile::index_t valid_tile_num = sorted_tile_num;
|
||||
ck_tile::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
|
||||
const ck_tile::index_t M = sorted_tile_num * MPerBlock;
|
||||
const ck_tile::index_t outputN = kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up ? N / 2 : N;
|
||||
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr bool IsInputGemm = kind != ck_tile::MoeFlatmmKind::kFFN_gemm2;
|
||||
|
||||
stride_A = ck_tile::get_default_stride(
|
||||
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(
|
||||
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(ck_tile::host_tensor_descriptor(
|
||||
IsInputGemm ? num_tokens : num_tokens * topk, K, stride_A, is_row_major(a_layout)));
|
||||
auto b_k_n_tensor = ck_tile::HostTensor<BDataType>(
|
||||
is_row_major(b_layout)
|
||||
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
|
||||
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(ck_tile::host_tensor_descriptor(
|
||||
IsInputGemm ? num_tokens * topk : num_tokens, outputN, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b(ck_tile::HostTensorDescriptor(
|
||||
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{0.0f, 1.0f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{0.f, 1.f}(scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.0f, 1.0f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.0f, 1.0f}(b_k_n_tensor);
|
||||
ck_tile::FillUniformDistribution<ScaleType>{1.0f, 1.0f}(scale_b);
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host(
|
||||
ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
shuffle_mxfp4_weight<FlatmmConfig, kind>(
|
||||
b_k_n_tensor.begin(), b_shuffle_host.begin(), experts, N, K);
|
||||
|
||||
ck_tile::HostTensor<ScaleType> scale_b_shuffle =
|
||||
shuffle_mxfp4_scale<FlatmmConfig, kind>(scale_b, experts);
|
||||
ck_tile::DeviceMem scale_b_shuffle_dev_buf(scale_b_shuffle.get_element_space_size_in_bytes());
|
||||
|
||||
std::cout << "moe_flatmm:"
|
||||
<< "\n num_experts: " << experts << "\n num_tokens: " << num_tokens
|
||||
<< "\n topk: " << topk << "\n sorted_tile_num: " << sorted_tile_num
|
||||
<< "\n problem_n: " << N << "\n problem_k: " << K
|
||||
<< "\n a_m_k: " << a_m_k_tensor.mDesc << "\n b_k_n: " << b_k_n_tensor.mDesc
|
||||
<< "\n b_shuffle: " << b_shuffle_host.mDesc << "\n c_m_n: " << c_m_n_tensor.mDesc
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::HostTensor<ck_tile::index_t> expert_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> sorted_token_ids(
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<AccDataType> expert_weight(
|
||||
ck_tile::HostTensorDescriptor({sorted_size}, {1}));
|
||||
ck_tile::HostTensor<ck_tile::index_t> max_token_id(
|
||||
ck_tile::HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
// for verification only, no need to satify weight normalization
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.0f, 1.0f}(expert_weight);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.0f, 1.0f}(expert_weight);
|
||||
}
|
||||
|
||||
max_token_id.mData = {valid_tile_num * MPerBlock, 0, 1, 2, 3, 4, 6, 7, 8, 8};
|
||||
// int eids[] = {0, 1, 2, 3, 4, 4, 5, 6, 3, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
// int token_per_tile = num_tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < num_tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % num_tokens) | ((tokenid / num_tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = num_tokens;
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf{a_m_k_tensor.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem b_origin_dev_buf{b_k_n_tensor.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem b_shuffle_dev_buf{b_shuffle_host.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem c_m_n_dev_buf{c_m_n_tensor.get_element_space_size_in_bytes()};
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_tensor.data());
|
||||
b_origin_dev_buf.ToDevice(b_k_n_tensor.data());
|
||||
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_tensor.SetZero();
|
||||
|
||||
ck_tile::DeviceMem sorted_token_ids_dev{sizeof(ck_tile::index_t) *
|
||||
sorted_token_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_ids_dev{sizeof(ck_tile::index_t) *
|
||||
expert_ids.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem max_token_id_dev{sizeof(ck_tile::index_t) *
|
||||
max_token_id.get_element_space_size_in_bytes()};
|
||||
ck_tile::DeviceMem expert_weight_dev{sizeof(AccDataType) *
|
||||
expert_weight.get_element_space_size_in_bytes()};
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.data());
|
||||
expert_weight_dev.ToDevice(expert_weight.data());
|
||||
scale_b_shuffle_dev_buf.ToDevice(scale_b_shuffle.data());
|
||||
|
||||
const ck_tile::index_t* p_sorted_token_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(sorted_token_ids_dev.GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_expert_ids_dev =
|
||||
static_cast<ck_tile::index_t*>(expert_ids_dev.GetDeviceBuffer());
|
||||
const ck_tile::index_t* p_max_token_id_dev =
|
||||
static_cast<ck_tile::index_t*>(max_token_id_dev.GetDeviceBuffer());
|
||||
const AccDataType* p_sorted_expert_weight_dev =
|
||||
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
|
||||
auto scale_b_shuffle_dev_ptr =
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
|
||||
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
|
||||
|
||||
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
|
||||
ck_tile::FlatmmScalePointer<-1>,
|
||||
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>>;
|
||||
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
|
||||
p_sorted_expert_weight_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_shuffle_dev_buf.GetDeviceBuffer(),
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
num_tokens,
|
||||
experts,
|
||||
topk,
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
nullptr,
|
||||
scale_b_shuffle_dev_ptr};
|
||||
|
||||
invoke_a16w4_moe_gemm<FlatmmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
kind>(warmup, repeat, gemm_desc);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_tensor.data());
|
||||
|
||||
bool pass{true};
|
||||
if(arg_parser.get_int("validate"))
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(IsInputGemm ? num_tokens * topk : num_tokens,
|
||||
outputN,
|
||||
stride_C,
|
||||
is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::HostTensor<AccDataType> scale_A(
|
||||
ck_tile::HostTensorDescriptor({1, K / ScaleGranularityK}, {1, 1}));
|
||||
|
||||
// scaleA = 1 has no effect on the result
|
||||
ck_tile::FillUniformDistribution<AccDataType>{1.f, 1.f}(scale_A);
|
||||
ck_tile::DeviceMem scale_A_dev_buf(scale_A.get_element_space_size_in_bytes());
|
||||
scale_A_dev_buf.ToDevice(scale_A.data());
|
||||
|
||||
// convert scale_b from e8m0 to float
|
||||
ck_tile::HostTensor<AccDataType> scale_b_float(ck_tile::HostTensorDescriptor(
|
||||
{K * experts / ScaleGranularityK, N / ScaleGranularityN}, {N / ScaleGranularityN, 1}));
|
||||
std::copy(scale_b.begin(), scale_b.end(), scale_b_float.begin());
|
||||
ck_tile::DeviceMem scale_b_float_dev_buf(scale_b_float.get_element_space_size_in_bytes());
|
||||
scale_b_float_dev_buf.ToDevice(scale_b_float.data());
|
||||
|
||||
std::unique_ptr<ck_tile::DeviceMem> c_m_n_ref_buf =
|
||||
std::make_unique<ck_tile::DeviceMem>(c_m_n_tensor.get_element_space_size_in_bytes());
|
||||
c_m_n_ref_buf->SetZero();
|
||||
|
||||
ck_tile::reference_moe_gemm_gpu<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
static_cast<int>(kind),
|
||||
std::conditional_t<IsInputGemm, LocalSilu, ck_tile::identity>>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
static_cast<const ADataType*>(a_m_k_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<const BDataType*>(b_origin_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_m_n_ref_buf->GetDeviceBuffer()),
|
||||
p_sorted_expert_weight_dev,
|
||||
num_tokens,
|
||||
MPerBlock,
|
||||
topk,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
M,
|
||||
1,
|
||||
ScaleGranularityK,
|
||||
static_cast<float*>(scale_A_dev_buf.GetDeviceBuffer()),
|
||||
static_cast<float*>(scale_b_float_dev_buf.GetDeviceBuffer()));
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol << " Absolute error threshold: " << atol
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
@@ -165,6 +165,10 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
constexpr int BlockedXDLN_PerWarp = moe_kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_gate_up
|
||||
? 2
|
||||
: 1; // determined by scale shuffle pattern
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -187,7 +191,8 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs<ScaleM, ScaleN>& args,
|
||||
FlatmmConfig::NumWaveGroups,
|
||||
false,
|
||||
1,
|
||||
FlatmmConfig::TiledMMAPermuteN>>;
|
||||
FlatmmConfig::TiledMMAPermuteN,
|
||||
BlockedXDLN_PerWarp>>;
|
||||
|
||||
using CodegenFlatmmPipeline =
|
||||
ck_tile::MoeFlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
@@ -170,7 +170,6 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
ck_tile::HostTensor<AccDataType> per_channel_scale(
|
||||
ck_tile::HostTensorDescriptor({N * experts}, {1}));
|
||||
|
||||
ck_tile::FillMonotonicSeq<AccDataType>{}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_token_scale);
|
||||
ck_tile::FillUniformDistribution<AccDataType>{0.f, 1.f}(per_channel_scale);
|
||||
|
||||
@@ -195,12 +194,12 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / (valid_tile_num / experts);
|
||||
expert_ids.mData[i] = i / ((valid_tile_num + experts - 1) / experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (num_tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
// int token_per_tile = num_tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
int tokenid = 0;
|
||||
// sorted_token_ids.mData[0] = 0;
|
||||
for(int i = 0; i < sorted_tile_num * MPerBlock; i++)
|
||||
{
|
||||
@@ -329,8 +328,8 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
K, 1 /*kbatch*/, max_accumulated_value);
|
||||
c_m_n_ref_buf->FromDevice(c_m_n_host_ref.data());
|
||||
|
||||
float rtol = 1e-3;
|
||||
float atol = 1e-3;
|
||||
const float rtol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
const float atol = std::is_same_v<ADataType, ck_tile::half_t> && IsInputGemm ? 1e-3 : 1e-2;
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
c_m_n_tensor, c_m_n_host_ref, "Error: Incorrect results!", rtol, atol);
|
||||
|
||||
@@ -82,11 +82,13 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
AccDataType acc_temp = 0.0;
|
||||
AccDataType acc_up_temp = 0.0;
|
||||
|
||||
float scale_A = 0;
|
||||
float scale_B = 0;
|
||||
float scale_B_up = 0;
|
||||
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
|
||||
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
|
||||
float scale_A = 0;
|
||||
float scale_B = 0;
|
||||
float scale_B_up = 0;
|
||||
|
||||
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
|
||||
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
|
||||
index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
@@ -101,12 +103,13 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
// update scale factors
|
||||
scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
|
||||
(k / scale_granularity_k) * scale_A_stride];
|
||||
scale_B = scale_B_ptr[((expert_id * N + col) / scale_granularity_n) +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
scale_B =
|
||||
scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
scale_B_up =
|
||||
scale_B_ptr[((expert_id * N + col + problem_N) / scale_granularity_n) +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
|
||||
(col + problem_N) / scale_granularity_n +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
}
|
||||
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
@@ -138,6 +141,14 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
@@ -159,6 +170,22 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
v_b_up = fp32_val_up.lo;
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
{
|
||||
const fp32x2_t fp32_val_up = pk_fp4_to_fp32x2(B[b_index_up / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b_up = fp32_val_up.hi;
|
||||
else
|
||||
v_b_up = fp32_val_up.lo;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
|
||||
@@ -370,6 +370,14 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
const auto& d_block_window = gemm_tile_windows.at(I2);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I4);
|
||||
|
||||
static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK
|
||||
|| ScaleM::GranularityMN == -1 // or ScaleA is disable
|
||||
|| ScaleN::GranularityMN == -1, // or ScaleB is disable
|
||||
"ScaleM and ScaleN should have the same GranularityK");
|
||||
constexpr bool DoEpiScale =
|
||||
(ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) || // per token
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0); // per channel
|
||||
|
||||
auto a_block_window_with_distr =
|
||||
ck_tile::make_tile_window(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
@@ -383,26 +391,21 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel<TilePartitioner_, FlatmmPipeline_, Ep
|
||||
smem_ptr_pong);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
if constexpr(false && (ScaleM::GranularityMN != -1 && ScaleM::GranularityK == 0) ||
|
||||
(ScaleN::GranularityMN != -1 && ScaleN::GranularityK == 0))
|
||||
if constexpr(DoEpiScale)
|
||||
{
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
EpiloguePipeline{}(c_block_window,
|
||||
c_block_tile,
|
||||
d_block_window,
|
||||
smem_ptr_ping,
|
||||
kargs.scale_m_ptr + block_idx_m,
|
||||
kargs.scale_n_ptr + block_idx_n);
|
||||
}
|
||||
else if(UseDefaultScheduler || (get_warp_id() == 0))
|
||||
{
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(I3);
|
||||
EpiloguePipeline{}.template
|
||||
operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
|
||||
c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -80,7 +80,8 @@ enum class MoeFlatmmKind
|
||||
template <typename TilePartitioner_,
|
||||
typename FlatmmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
MoeFlatmmKind kind>
|
||||
MoeFlatmmKind kind,
|
||||
typename FusedActivation = element_wise::Silu>
|
||||
struct MoeFlatmmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
@@ -101,7 +102,8 @@ struct MoeFlatmmKernel
|
||||
// Below type is actually accumulation data type - the output of block GEMM.
|
||||
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
|
||||
|
||||
using AccDataType = float;
|
||||
using AccDataType = float;
|
||||
using ActivationOp = FusedActivation;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
@@ -114,6 +116,7 @@ struct MoeFlatmmKernel
|
||||
"The size of DsLayout and DsDataType should be the same");
|
||||
|
||||
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
|
||||
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
|
||||
|
||||
static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
|
||||
@@ -128,6 +131,17 @@ struct MoeFlatmmKernel
|
||||
static constexpr index_t kNPerIteration = NPerXdl * NWave;
|
||||
static constexpr index_t kNRepeat = kNPerBlock / kNPerIteration;
|
||||
|
||||
static constexpr int OutputNPerBlock =
|
||||
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
|
||||
|
||||
// MXF4_Pipeline only has the of scale B and granularityK is 32
|
||||
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
|
||||
static constexpr int MXFP4N_Pack = 2;
|
||||
|
||||
static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
|
||||
|
||||
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
struct MoeFlatmmKernelArgs
|
||||
{
|
||||
@@ -405,10 +419,10 @@ struct MoeFlatmmKernel
|
||||
const BDataType* b_flat_ptr,
|
||||
EDataType* e_ptr,
|
||||
const AccDataType* exp_weight_ptr,
|
||||
const int expert_id,
|
||||
const KernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
// static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
@@ -432,9 +446,9 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
|
||||
BlockGemmShape::WarpTile::at(number<2>{}));
|
||||
index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); // TODO (support splitK)
|
||||
index_t kFlatN = kargs.N * kargs.K / kFlatK;
|
||||
|
||||
const auto& b_flat_tensor_view = [&]() {
|
||||
return make_naive_tensor_view<address_space_enum::global>(
|
||||
b_flat_ptr,
|
||||
@@ -451,7 +465,7 @@ struct MoeFlatmmKernel
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumTokens,
|
||||
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kargs.N / 2 : kargs.N),
|
||||
IsGateUp ? kargs.N / 2 : kargs.N),
|
||||
make_tuple(kargs.stride_C, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
@@ -461,14 +475,30 @@ struct MoeFlatmmKernel
|
||||
return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
|
||||
e_ptr,
|
||||
make_tuple(IsInputGemm ? kargs.NumTokens * kargs.TopK : kargs.NumToken,
|
||||
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kargs.N / 2 : kargs.N),
|
||||
IsGateUp ? kargs.N / 2 : kargs.N),
|
||||
make_tuple(1, kargs.stride_C),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view);
|
||||
auto scale_n = kargs.scale_n;
|
||||
constexpr int GranularityK = decltype(scale_n)::GranularityK;
|
||||
|
||||
index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
|
||||
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
|
||||
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
|
||||
|
||||
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
|
||||
|
||||
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
|
||||
make_tuple(FlatScaleN, FlatScaleK),
|
||||
make_tuple(FlatScaleK, 1),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
|
||||
}
|
||||
|
||||
template <typename TensorView>
|
||||
@@ -492,14 +522,9 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const auto& b_flat_tensor_view = views.at(I1);
|
||||
|
||||
// TODO vector write in for C in ColMajor
|
||||
const auto& c_pad_view = [&]() {
|
||||
const auto& c_tensor_view = views.at(I2);
|
||||
constexpr int OutputNPerBlock = kind == MoeFlatmmKind::kFFN_gemm1_gate_up
|
||||
? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock;
|
||||
const auto& c_tensor_view = views.at(I2);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return pad_tensor_view(
|
||||
@@ -516,12 +541,13 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view);
|
||||
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
|
||||
}
|
||||
|
||||
template <typename PadView>
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, [[maybe_unused]] const index_t i_m, const index_t i_n)
|
||||
CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
|
||||
[[maybe_unused]] const index_t coord_m,
|
||||
const index_t coord_n)
|
||||
{
|
||||
const auto& a_pad_view = views.at(number<0>{});
|
||||
const auto& b_flat_pad_view = views.at(number<1>{});
|
||||
@@ -533,7 +559,7 @@ struct MoeFlatmmKernel
|
||||
return make_tile_window(a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{},
|
||||
number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0}); // NOTE!
|
||||
{coord_m, 0}); // NOTE!
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -544,25 +570,33 @@ struct MoeFlatmmKernel
|
||||
}
|
||||
}();
|
||||
|
||||
const int problem_N_offset = kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? i_n / 2 : i_n;
|
||||
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
|
||||
|
||||
const auto& b_flat_block_window = make_tile_window(
|
||||
b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(problem_N_offset / BlockGemmShape::WarpTile::at(I1)), 0});
|
||||
const auto& b_flat_block_window =
|
||||
make_tile_window(b_flat_pad_view,
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp>{}),
|
||||
{static_cast<int>(coord_n / BlockGemmShape::WarpTile::at(I1) /
|
||||
(isNonInterleaveGateUp ? 1 : 2)),
|
||||
0});
|
||||
|
||||
constexpr int OutputNPerBlock = kind == MoeFlatmmKind::kFFN_gemm1_gate_up
|
||||
? TilePartitioner::NPerBlock / 2
|
||||
: TilePartitioner::NPerBlock;
|
||||
const int output_N_offset = IsGateUp ? coord_n / 2 : coord_n;
|
||||
|
||||
auto c_block_window = make_tile_window(
|
||||
c_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<OutputNPerBlock>{}),
|
||||
{0, // offset_m is included when construct C-scatter-window offsets
|
||||
problem_N_offset});
|
||||
output_N_offset});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window);
|
||||
constexpr int GranularityK = 32;
|
||||
|
||||
auto scale_block_window = make_tile_window(
|
||||
views.at(I3),
|
||||
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
|
||||
number<FlatmmPipeline::flatKPerWarp * N_Pack * 4 / GranularityK>{}),
|
||||
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
|
||||
|
||||
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
|
||||
}
|
||||
|
||||
template <class ScaleM = FlatmmScalePointer<-1>, class ScaleN = FlatmmScalePointer<-1>>
|
||||
@@ -614,16 +648,16 @@ struct MoeFlatmmKernel
|
||||
|
||||
const ADataType* a_ptr =
|
||||
static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
|
||||
const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
splitk_batch_offset.b_k_split_offset +
|
||||
expert_stride * expert_id;
|
||||
const BDataType* b_flat_ptr =
|
||||
static_cast<const BDataType*>(kargs.b_ptr) +
|
||||
(splitk_batch_offset.b_k_split_offset + expert_stride * expert_id) / WeightPackedSize;
|
||||
EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
|
||||
|
||||
const AccDataType* exp_weight_ptr =
|
||||
static_cast<const AccDataType*>(kargs.p_sorted_expert_weights);
|
||||
|
||||
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews(
|
||||
a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, kargs, splitk_batch_offset);
|
||||
a_ptr, b_flat_ptr, e_ptr, exp_weight_ptr, expert_id, kargs, splitk_batch_offset);
|
||||
const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
|
||||
auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, coord_m, coord_n);
|
||||
@@ -631,26 +665,43 @@ struct MoeFlatmmKernel
|
||||
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(number<0>{});
|
||||
const auto& b_block_window = gemm_tile_windows.at(number<1>{});
|
||||
const auto& a_block_window = gemm_tile_windows.at(I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(I1);
|
||||
const auto& scale_block_window = gemm_tile_windows.at(I3);
|
||||
|
||||
auto a_gather_block_tile =
|
||||
ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
|
||||
a_block_window.get_window_lengths(),
|
||||
a_block_window.get_window_origin(),
|
||||
FlatmmPipeline::GetADramTileDistribution(),
|
||||
a_dram_dist,
|
||||
a_offsets); // K DRAM tile window for
|
||||
auto c_block_tile = FlatmmPipeline{}(a_gather_block_tile,
|
||||
b_block_window,
|
||||
number<kind == MoeFlatmmKind::kFFN_gemm1_gate_up>{},
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
using AccTile = decltype(c_block_tile);
|
||||
|
||||
auto c_block_tile = [&] {
|
||||
if constexpr(MXFP4_Pipeline)
|
||||
{
|
||||
// MXFP4_Pipeline uses gate-up interleave 16 layout for weight
|
||||
// so don't need extra processing
|
||||
return FlatmmPipeline{}(a_gather_block_tile,
|
||||
b_block_window,
|
||||
scale_block_window, // weight scale with granularityK = 32
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FlatmmPipeline{}(a_gather_block_tile,
|
||||
b_block_window,
|
||||
number<IsGateUp>{},
|
||||
num_loop,
|
||||
smem_ptr_ping,
|
||||
smem_ptr_pong);
|
||||
}
|
||||
}();
|
||||
using AccTile = decltype(c_block_tile);
|
||||
|
||||
// Run EpiloguePipeline Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(number<2>{});
|
||||
using ActivationOp = element_wise::Silu;
|
||||
|
||||
{
|
||||
using EpiProblem = typename EpiloguePipeline::Problem;
|
||||
@@ -666,26 +717,34 @@ struct MoeFlatmmKernel
|
||||
constexpr index_t MRepeat = EpiloguePipeline::MRepeat;
|
||||
constexpr index_t NRepeat = EpiloguePipeline::NRepeat;
|
||||
|
||||
constexpr auto lds_block_desc =
|
||||
EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
|
||||
static_assert(!IsGateUp || NumNXdlPerWavePerShuffle % 2 == 0);
|
||||
|
||||
constexpr index_t OutputNumNXdlPerWavePerShuffle =
|
||||
IsGateUp ? NumNXdlPerWavePerShuffle / 2 : NumNXdlPerWavePerShuffle;
|
||||
constexpr index_t LDS_NPerIterationShuffle =
|
||||
IsGateUp ? NPerIterationShuffle / 2 : NPerIterationShuffle;
|
||||
|
||||
constexpr auto lds_block_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
|
||||
make_tuple(number<LDS_NPerIterationShuffle>{}, number<1>{}));
|
||||
|
||||
// EpiloguePipeline::template MakeLdsBlockDescriptor<EpiProblem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<ODataType*>(smem_ptr_ping), lds_block_desc);
|
||||
|
||||
auto in_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
auto out_lds_window = make_tile_window(
|
||||
o_lds_block,
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
|
||||
make_tuple(number<MPerIterationShuffle>{}, number<LDS_NPerIterationShuffle>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<
|
||||
sequence<kMPerBlock,
|
||||
kind == MoeFlatmmKind::kFFN_gemm1_gate_up ? kNPerBlock / 2 : kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
@@ -696,7 +755,7 @@ struct MoeFlatmmKernel
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<
|
||||
kBlockSize,
|
||||
MPerIterationShuffle,
|
||||
NPerIterationShuffle,
|
||||
LDS_NPerIterationShuffle,
|
||||
kind == MoeFlatmmKind::kFFN_gemm2 ? 2 : EpiloguePipeline::GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked,
|
||||
EpiProblem::kNumWaveGroups>;
|
||||
@@ -704,8 +763,24 @@ struct MoeFlatmmKernel
|
||||
constexpr auto dram_tile_distribution =
|
||||
TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
constexpr auto LdsTileDistr =
|
||||
make_static_tile_distribution(EpiloguePipeline::MakeLdsDistributionEncode());
|
||||
constexpr auto LdsTileDistr = [&] {
|
||||
if constexpr(IsGateUp)
|
||||
return make_static_tile_distribution(
|
||||
detail::make_embed_tile_distribution_encoding(
|
||||
tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<NumMXdlPerWavePerShuffle, MWave>,
|
||||
// merge two contiguous N
|
||||
sequence<OutputNumNXdlPerWavePerShuffle, NWave>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{},
|
||||
typename CWarpDstr::DstrEncode{}));
|
||||
else
|
||||
return make_static_tile_distribution(
|
||||
EpiloguePipeline::MakeLdsDistributionEncode());
|
||||
}();
|
||||
|
||||
using LDSTileTensor =
|
||||
decltype(make_static_distributed_tensor<AccDataType>(LdsTileDistr));
|
||||
@@ -719,8 +794,8 @@ struct MoeFlatmmKernel
|
||||
constexpr int kM1 = (64 / NPerXdl); // Thr
|
||||
constexpr int kM0 = MPerXdl / kM1 / kM2; // Val
|
||||
|
||||
constexpr int ActVectorSize =
|
||||
c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle * NumNXdlPerWavePerShuffle;
|
||||
constexpr int ActVectorSize = c_warp_y_lengths.product() * NumMXdlPerWavePerShuffle *
|
||||
OutputNumNXdlPerWavePerShuffle;
|
||||
|
||||
const index_t iMWarp = get_warp_id() / NWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * NWave;
|
||||
@@ -737,32 +812,36 @@ struct MoeFlatmmKernel
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Load scales and expert weights
|
||||
//===----------------------------------------------------------------------===//
|
||||
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
{
|
||||
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
|
||||
vec_scale_B[i] = kargs.scale_n[expert_id * kargs.N + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
vec_scale_B[i + NRepeat / 2] =
|
||||
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
static_for<0, NRepeat / 2, 1>{}([&](auto i) {
|
||||
vec_scale_B[i * 2] =
|
||||
kargs.scale_n[expert_id * kargs.N + coord_n / 2 + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
vec_scale_B[i * 2 + 1] =
|
||||
kargs.scale_n[expert_id * kargs.N + kargs.N / 2 + coord_n / 2 +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto i) {
|
||||
vec_scale_B[i] =
|
||||
kargs.scale_n[expert_id * kargs.N + coord_n + i * NWave * NPerXdl +
|
||||
iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto i) {
|
||||
vec_scale_B[i] = kargs.scale_n[expert_id * kargs.N + coord_n +
|
||||
i * NWave * NPerXdl + iNWarp * NPerXdl + iNLane];
|
||||
});
|
||||
}
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto i) {
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
index_t M2_offset = m2 + iMLane * kM2 + m0 * kM2 * kM1 + iMWarp * MPerXdl +
|
||||
i * MPerXdl * MWave + coord_m;
|
||||
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
|
||||
kargs.scale_m[row_to_token_idx(M2_offset)];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
vec_scale_A[i * kM0 * kM2 + m0 * kM2 + m2] =
|
||||
kargs.scale_m[row_to_token_idx(M2_offset)];
|
||||
if constexpr(!IsInputGemm)
|
||||
vec_expert_weights[i * kM0 * kM2 + m0 * kM2 + m2] =
|
||||
expert_weights[M2_offset];
|
||||
@@ -770,46 +849,54 @@ struct MoeFlatmmKernel
|
||||
});
|
||||
});
|
||||
|
||||
constexpr int UpAccStride = NRepeat / 2;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pingpong process start
|
||||
//===----------------------------------------------------------------------===//
|
||||
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
LDSTileTensor gate_tensor, up_tensor;
|
||||
|
||||
static_assert((NRepeat / NumNXdlPerWavePerShuffle) % 2 == 0);
|
||||
// gate and up are interleaved along NRepeat dimension.
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
gate_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
|
||||
gate_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<0 * NumMXdlPerWavePerShuffle, 0 * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
up_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle,
|
||||
0 * NumNXdlPerWavePerShuffle + UpAccStride>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
up_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{}, c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0 * NumMXdlPerWavePerShuffle, 2 * n_xdl + 1>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl + UpAccStride];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[2 * n_xdl + 1];
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
@@ -830,16 +917,18 @@ struct MoeFlatmmKernel
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) * c_warp_y_lengths.product();
|
||||
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) * c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
if constexpr(!IsInputGemm)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_expert_weights[m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
lds_tile[0].get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[0]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[n_xdl];
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -875,51 +964,62 @@ struct MoeFlatmmKernel
|
||||
|
||||
if constexpr(iAccess < num_access - 1)
|
||||
{
|
||||
if constexpr(kind == MoeFlatmmKind::kFFN_gemm1_gate_up)
|
||||
if constexpr(IsGateUp)
|
||||
{
|
||||
LDSTileTensor gate_tensor, up_tensor;
|
||||
|
||||
gate_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(
|
||||
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
up_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(
|
||||
sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle + UpAccStride>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(
|
||||
sequence<NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle>{},
|
||||
c_warp_y_lengths));
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
gate_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
gate_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
up_tensor
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl + UpAccStride];
|
||||
up_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<0, n_xdl>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths),
|
||||
c_block_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter_next * NumMXdlPerWavePerShuffle,
|
||||
nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1>{},
|
||||
c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<NumMXdlPerWavePerShuffle, 1>{},
|
||||
c_warp_y_lengths)));
|
||||
});
|
||||
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
static_for<0, OutputNumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl * OutputNumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
gate_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl];
|
||||
up_tensor.get_thread_buffer()[acc_xdl_offset +
|
||||
m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
2 * n_xdl + 1];
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
|
||||
ActivationOp{}(gate_tensor.get_thread_buffer().at(idx),
|
||||
gate_tensor.get_thread_buffer().at(idx));
|
||||
@@ -941,7 +1041,7 @@ struct MoeFlatmmKernel
|
||||
static_for<0, NumNXdlPerWavePerShuffle, 1>{}([&](auto n_xdl) {
|
||||
static_for<0, NumMXdlPerWavePerShuffle, 1>{}([&](auto m_xdl) {
|
||||
constexpr int acc_xdl_offset =
|
||||
(m_xdl + n_xdl * NumMXdlPerWavePerShuffle) *
|
||||
(m_xdl * NumNXdlPerWavePerShuffle + n_xdl) *
|
||||
c_warp_y_lengths.product();
|
||||
static_for<0, kM0, 1>{}([&](auto m0) {
|
||||
static_for<0, kM2, 1>{}([&](auto m2) {
|
||||
@@ -951,13 +1051,15 @@ struct MoeFlatmmKernel
|
||||
m2] *= vec_expert_weights
|
||||
[mIter_next * NumMXdlPerWavePerShuffle * kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2];
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 + m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
if constexpr(!MXFP4_Pipeline)
|
||||
lds_tile[write_stage]
|
||||
.get_thread_buffer()[acc_xdl_offset + m0 * kM2 +
|
||||
m2] *=
|
||||
vec_scale_A[mIter_next * NumMXdlPerWavePerShuffle *
|
||||
kM0 * kM2 +
|
||||
m_xdl * kM0 * kM2 + m0 * kM2 + m2] *
|
||||
vec_scale_B[nIter_next * NumNXdlPerWavePerShuffle +
|
||||
n_xdl];
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -582,19 +582,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
|
||||
{
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
});
|
||||
});
|
||||
@@ -637,18 +638,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
|
||||
{
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -723,18 +726,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
|
||||
{
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
@@ -812,18 +817,20 @@ struct MoeFlatmmPipelineAGmemBGmemCRegV1
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
|
||||
|
||||
if constexpr(!IsGateUpMode || nIter < NIterPerWarp / 2)
|
||||
{
|
||||
if constexpr(!IsGateUpMode)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{(nIter - NIterPerWarp / 2) * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
if constexpr(nIter % 2 == 0)
|
||||
move_tile_window(
|
||||
b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
|
||||
else
|
||||
move_tile_window(b_flat_dram_windows(nIter)(kIter),
|
||||
{nIter / 2 * NFlatPerBlockPerIter + up_weight_stride,
|
||||
kIter * KFlatPerBlockPerIter});
|
||||
}
|
||||
|
||||
b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter));
|
||||
|
||||
Reference in New Issue
Block a user