Add bias for f16xf4 moe_flatmm

This commit is contained in:
Feng Shijie
2025-08-28 08:02:50 +00:00
parent dd6539f366
commit 5c484a5672
5 changed files with 179 additions and 92 deletions

View File

@@ -40,7 +40,8 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr)
float* scale_B_ptr,
float* expert_bias_ptr)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
@@ -200,18 +201,26 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
acc += acc_temp * scale_A * scale_B;
acc_up += acc_up_temp * scale_A * scale_B_up;
float bias = 0.f, bias_up = 0.f;
if(expert_bias_ptr != nullptr)
{
bias = expert_bias_ptr[expert_id * N + col];
if constexpr(MoeGemmKind == 1)
bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? scatter_token_id * strideC + col
: col * strideC + scatter_token_id;
if constexpr(MoeGemmKind < 2)
{
C[c_index] = ck_tile::type_convert<CDataType>(
ActivationOp{}(acc, MoeGemmKind == 1 ? acc_up : 1));
ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
}
else
{
// moe gemm2 does not apply an activation.
CDataType res = ck_tile::type_convert<CDataType>(acc * expert_weight_ptr[row]);
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
ck_tile::fp16x2_t,
ck_tile::bf16x2_t>;
@@ -261,7 +270,8 @@ void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr)
float* scale_B_ptr,
float* exp_bias = nullptr)
{
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int totalElements = M * problem_N;
@@ -296,7 +306,8 @@ void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
scale_granularity_n,
scale_granularity_k,
scale_A_ptr,
scale_B_ptr);
scale_B_ptr,
exp_bias);
return;
}