[CK_TILE] support split-k a16w4 gemm1 (#3389)

* initial version to support moe gemm1 split-k

* add missing args

* fix build warning

* update reference

* for split-k disable bias and weight

* remove debug log

* fix format

* fix div by zero errors

* fix cmake config

* update

* resolve conflicts

* remove useless changes

* reformat

* fix

* remove useless changes

* fix ci

---------

Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>

[ROCm/composable_kernel commit: dae85ead64]
This commit is contained in:
yadaish
2025-12-29 23:05:35 +08:00
committed by GitHub
parent 3772cf9dd4
commit fc3ffa0d75
11 changed files with 136 additions and 78 deletions

View File

@@ -31,13 +31,14 @@ if(has_supported_gpu)
add_executable(tile_example_grouped_flatmm grouped_flatmm.cpp)
target_compile_options(tile_example_grouped_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
if (GPU_TARGETS MATCHES "gfx95")
if(GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx94")
add_executable(tile_example_mixed_prec_flatmm mixed_prec/mixed_prec_flatmm.cpp)
target_compile_options(tile_example_mixed_prec_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
add_executable(tile_example_a16w4_moe_flatmm mixed_prec/a16w4_moe_flatmm.cpp)
target_compile_options(tile_example_a16w4_moe_flatmm PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})
endif()
if (GPU_TARGETS MATCHES "gfx95")
include(mxgemm/mx_flatmm_instance.cmake)
mx_flatmm_instance_generate(EXAMPLE_MX_FLATMM_FILES)
message(STATUS "Generated MX FlatMM kernel files: ${EXAMPLE_MX_FLATMM_FILES}")

View File

@@ -8,7 +8,7 @@
// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 64;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;

View File

@@ -191,13 +191,15 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
std::cout << "Launching kernel " << Kernel::GetName() << "\n"
<< "with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
<< "\n"
<< "k_batch: " << kargs.k_batch << std::endl;
}
if(s.flush_cache_)
@@ -471,10 +473,33 @@ int run_a16w4_moe_flatmm_example(int argc, char* argv[])
throw std::runtime_error("Unsupported precision type for gemm2!");
}
}
else if(gemm_kind == "gemm1_split_k")
{
if(mixed_prec == "fp16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::half_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
}
else if(mixed_prec == "bf16xfp4")
{
return run_a16w4_moe_gemm_example_with_layouts<
ck_tile::bfloat16_t,
ck_tile::pk_fp4_t,
FlatmmConfig,
ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k>(argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported precision type for gemm1_split_k!");
}
}
else
{
throw std::runtime_error("Unrecognized gemm_kind parameter, only accept value "
"[gemm1_gate_up | gemm2]");
"[gemm1_gate_up | gemm1_split_k | gemm2]");
}
}
else

View File

@@ -13,7 +13,7 @@
// GEMM config with 16x16 warp tile
struct A16W4_FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t M_Tile = 32;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 256;
@@ -69,7 +69,7 @@ auto create_args(int argc, char* argv[])
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("gemm_kind",
"gemm1_gate_up",
"Gemm kind in FFN network [gemm1_gate_up | gemm2] - "
"Gemm kind in FFN network [gemm1_gate_up | gemm2 | gemm1_split_k] - "
"gemm1_gate_up by default.")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
@@ -80,7 +80,8 @@ auto create_args(int argc, char* argv[])
.insert("warp_tile",
"0",
"0: 16x16, 1: 16x16 (950 only, may use a larger tile than warp_tile=0)")
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
.insert("repeat", "10", "number of iterations to benchmark the kernel.")
.insert("k_batch", "1", "parallelism to control split-k.");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);

View File

@@ -67,9 +67,12 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
return -1;
};
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using CDataType = PrecActType;
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using ADataType = PrecActType;
using BDataType = PrecWeightType;
using CDataType =
std::conditional_t<kind == ck_tile::MoeFlatmmKind::kFFN_gemm1_split_k, float, PrecActType>;
using AccDataType = float;
using ScaleType = ck_tile::e8m0_t;
@@ -88,6 +91,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
const ck_tile::index_t warmup = arg_parser.get_int("warmup");
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
const ck_tile::index_t experts = arg_parser.get_int("experts");
const ck_tile::index_t k_batch = arg_parser.get_int("k_batch");
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = FlatmmConfig::M_Tile;
@@ -231,14 +235,15 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
static_cast<AccDataType*>(expert_weight_dev.GetDeviceBuffer());
auto scale_b_shuffle_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
static_cast<float*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_b_shuffle_dev_buf.GetDeviceBuffer()),
N / ScaleGranularityN};
auto exp_bias_dev_ptr = ck_tile::FlatmmScalePointer<1>{
static_cast<float*>(expert_bias_dev.GetDeviceBuffer()), experts * N};
using MoeFlatmmArgs = ck_tile::MoeFlatmmHostArgs<
ck_tile::FlatmmScalePointer<-1>,
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>,
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>,
ck_tile::FlatmmScalePointer<1>>;
MoeFlatmmArgs gemm_desc{p_sorted_token_ids_dev,
p_sorted_expert_weight_dev,
@@ -250,7 +255,7 @@ int run_a16w4_moe_gemm_example_with_layouts(int argc,
num_tokens,
experts,
topk,
1, // k_batch
k_batch, // k_batch
M,
N,
K,

View File

@@ -85,8 +85,9 @@ int run_mixed_prec_flatmm_with_layouts(int argc,
c_rslt_host.SetZero();
scale_b_dev_buf.ToDevice(scale_b_shuffle.data());
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK>{
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
auto scale_b_dev_ptr =
ck_tile::FlatmmScalePointer<DequantGranularityN, DequantGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / DequantGranularityN};
invoke_mixed_prec_flatmm<FlatmmConfig,
ADataType,

View File

@@ -25,14 +25,16 @@ using BF16 = ck_tile::bf16_t;
using ROW = ck_tile::tensor_layout::gemm::RowMajor;
using COL = ck_tile::tensor_layout::gemm::ColumnMajor;
using ScaleType = ck_tile::e8m0_t;
inline constexpr auto ODD = ck_tile::TailNumber::Odd;
inline constexpr auto EVEN = ck_tile::TailNumber::Even;
inline constexpr int ScaleGranularityM = 1;
inline constexpr int ScaleGranularityN = 1;
inline constexpr int ScaleGranularityK = 32;
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>;
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>;
using ScaleM = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>;
using ScaleN = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>;
template float mx_flatmm_calc<FLATMM_CONFIG,
A_DATA_TYPE,

View File

@@ -105,10 +105,12 @@ int run_mx_flatmm_with_layouts(int argc,
scale_a_dev_buf.ToDevice(scale_a_shuffled.data());
scale_b_dev_buf.ToDevice(scale_b_shuffled.data());
auto scale_a_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK>{
static_cast<float*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
auto scale_b_dev_ptr = ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK>{
static_cast<float*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
auto scale_a_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityM, ScaleGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_a_dev_buf.GetDeviceBuffer()), M / ScaleGranularityM};
auto scale_b_dev_ptr =
ck_tile::FlatmmScalePointer<ScaleGranularityN, ScaleGranularityK, ScaleType>{
static_cast<ScaleType*>(scale_b_dev_buf.GetDeviceBuffer()), N / ScaleGranularityN};
invoke_mx_flatmm<FlatmmConfig,
ADataType,

View File

@@ -18,7 +18,7 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
const ck_tile::index_t* p_sorted_expert_ids_,
@@ -43,10 +43,11 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
float* scale_B_ptr,
float* expert_bias_ptr)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int row = idx / problem_N; // Compute row index
int col = idx % problem_N; // Compute column index
constexpr auto is_split_k = MoeGemmKind == 3;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int row = idx / problem_N; // Compute row index
int col = idx % problem_N; // Compute column index
index_t gather_token_id = 0;
index_t scatter_token_id = 0;
@@ -203,7 +204,7 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
acc_up += acc_up_temp * scale_A * scale_B_up;
float bias = 0.f, bias_up = 0.f;
if(expert_bias_ptr != nullptr)
if(expert_bias_ptr != nullptr && !is_split_k)
{
bias = expert_bias_ptr[expert_id * N + col];
if constexpr(MoeGemmKind == 1)
@@ -221,23 +222,24 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
else
{
// moe gemm2 don't use activation.
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
ck_tile::fp16x2_t,
ck_tile::bf16x2_t>;
ResV2Type add_v{0, 0};
auto weight =
is_split_k ? ck_tile::type_convert<AccDataType>(1.0f) : expert_weight_ptr[row];
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * weight);
thread_buffer<CDataType, 2> add_v = 0;
if(c_index % 2)
{
// result is the second value of fp16 pair.
add_v.y = res;
add_v.template get_as<CDataType>()[1] = res;
}
else
{
// result is the first value of fp16 pair.
add_v.x = res;
add_v.template get_as<CDataType>()[0] = res;
}
// mask last bit to make sure atomicAdd pointer is aligned of DWORD.
atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
atomic_add_g<CDataType, 2>(reinterpret_cast<CDataType*>(C + (c_index & 0xffff'fffe)),
add_v);
}
}
}
@@ -249,7 +251,7 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
const index_t* p_sorted_expert_ids_,

View File

@@ -28,17 +28,18 @@ struct FlatmmProblem
index_t stride_C;
};
template <int SharedGranularityMN, int SharedGranularityK = 0>
template <int SharedGranularityMN, int SharedGranularityK = 0, typename ScaleType_ = float>
struct FlatmmScalePointer
{
using ScaleType = ScaleType_;
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = SharedGranularityK;
const float* ptr;
const ScaleType* ptr;
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, [[maybe_unused]] index_t length_)
: ptr(ptr_)
{
}
@@ -57,23 +58,24 @@ struct FlatmmScalePointer
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const = delete;
};
template <int SharedGranularityMN>
struct FlatmmScalePointer<SharedGranularityMN, 0>
template <int SharedGranularityMN, typename ScaleType_>
struct FlatmmScalePointer<SharedGranularityMN, 0, ScaleType_>
{
using ScaleType = ScaleType_;
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = 0;
static_assert(GranularityMN != 0);
const float* ptr;
const ScaleType* ptr;
index_t length;
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_) : ptr(ptr_), length(1) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const ScaleType* ptr_, index_t length_)
: ptr(ptr_), length(length_)
{
}
@@ -94,7 +96,7 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const
CK_TILE_HOST_DEVICE ScaleType operator[](index_t i) const
{
// with additional oob check
if constexpr(GranularityMN == 1)
@@ -105,23 +107,24 @@ struct FlatmmScalePointer<SharedGranularityMN, 0>
};
// shared granularityMN = -1 means no scale
template <>
struct FlatmmScalePointer<-1, 0>
template <typename ScaleType_>
struct FlatmmScalePointer<-1, 0, ScaleType_>
{
using ScaleType = ScaleType_;
static constexpr int GranularityMN = -1;
static constexpr int GranularityK = 0;
const float* ptr = nullptr;
const ScaleType* ptr = nullptr;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const ScaleType*, index_t) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
{
return FlatmmScalePointer{};
}
CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
CK_TILE_HOST_DEVICE constexpr ScaleType operator[](index_t) const
{
return 1; // always return 1, it doesn't change the result
}

View File

@@ -132,6 +132,7 @@ enum class MoeFlatmmKind
kFFN_gemm1_gate_only,
kFFN_gemm1_gate_up,
kFFN_gemm2,
kFFN_gemm1_split_k,
};
namespace moe {
@@ -222,8 +223,10 @@ struct MoeFlatmmKernel
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
static constexpr bool IsInputGemm = kind != MoeFlatmmKind::kFFN_gemm2;
static constexpr bool IsGateUp = kind == MoeFlatmmKind::kFFN_gemm1_gate_up;
static constexpr bool IsGemm1SplitK = kind == MoeFlatmmKind::kFFN_gemm1_split_k;
static constexpr bool IsBShuffled = true;
// static constexpr index_t kBlockSize = EpiloguePipeline::kBlockSize;
static constexpr index_t kMPerBlock = EpiloguePipeline::kMPerBlock;
@@ -395,15 +398,6 @@ struct MoeFlatmmKernel
a_k_split_offset = k_id * KRead * kargs.stride_A;
}
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * KRead * kargs.stride_B;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = k_id * KRead;
}
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
{
splitted_k = KRead;
@@ -412,6 +406,22 @@ struct MoeFlatmmKernel
{
splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
}
if constexpr(IsBShuffled)
{
b_k_split_offset = k_id * splitted_k * NPerXdl;
}
else
{
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * KRead * kargs.stride_B;
}
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
b_k_split_offset = k_id * KRead;
}
}
}
index_t a_k_split_offset;
@@ -573,15 +583,16 @@ struct MoeFlatmmKernel
return DTesnorIsValid;
}
template <memory_operation_enum DstInMemOp = IsInputGemm ? memory_operation_enum::set
: memory_operation_enum::atomic_add,
template <memory_operation_enum DstInMemOp = (IsInputGemm && !IsGemm1SplitK)
? memory_operation_enum::set
: memory_operation_enum::atomic_add,
typename KernelArgs>
CK_TILE_DEVICE static auto
MakeGemmTensorViews(const ADataType* a_ptr,
const BDataType* b_flat_ptr,
EDataType* e_ptr,
[[maybe_unused]] const AccDataType* exp_weight_ptr,
const int expert_id,
[[maybe_unused]] const int expert_id,
const KernelArgs& kargs,
const SplitKBatchOffset& splitk_batch_offset)
{
@@ -742,13 +753,13 @@ struct MoeFlatmmKernel
{
index_t scale_k =
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
const auto scale_k_offset =
(splitk_batch_offset.b_k_split_offset / BGranularityK) * K_Pack;
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
using ScaleType = e8m0_t;
return make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
scale_n.ptr + expert_id * kargs.N * scale_k + scale_k_offset,
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
@@ -1386,11 +1397,16 @@ struct MoeFlatmmKernel
if constexpr(!BMXFP4_Pipeline)
lds_tile[lds_stage].get_thread_buffer()[idx] *=
epi_scale_m[idx] * epi_scale_n[idx];
if constexpr(EnableBias)
lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
if constexpr(!IsInputGemm)
lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
else // for mlp1 gate-only
if(kind !=
MoeFlatmmKind::kFFN_gemm1_split_k) // disable weight and bias for split-k
{
if constexpr(EnableBias)
lds_tile[lds_stage].get_thread_buffer()[idx] += epi_exp_bias[idx];
if constexpr(!IsInputGemm)
lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_exp_weight[idx];
}
if constexpr(kind ==
MoeFlatmmKind::kFFN_gemm1_gate_only) // for mlp1 gate-only
lds_tile[lds_stage].get_thread_buffer()[idx] =
ActivationOp{}(lds_tile[lds_stage].get_thread_buffer()[idx]);
});