add fp16xf4 moe

This commit is contained in:
Feng Shijie
2025-08-18 17:28:11 +00:00
parent 599e1f5b32
commit be55c0f9cb
10 changed files with 1345 additions and 214 deletions

View File

@@ -82,11 +82,13 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
AccDataType acc_temp = 0.0;
AccDataType acc_up_temp = 0.0;
float scale_A = 0;
float scale_B = 0;
float scale_B_up = 0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
float scale_A = 0;
float scale_B = 0;
float scale_B_up = 0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
for(int k = 0; k < K; ++k)
{
@@ -101,12 +103,13 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
// update scale factors
scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
(k / scale_granularity_k) * scale_A_stride];
scale_B = scale_B_ptr[((expert_id * N + col) / scale_granularity_n) +
(k / scale_granularity_k) * scale_B_stride];
scale_B =
scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
if constexpr(MoeGemmKind == 1)
scale_B_up =
scale_B_ptr[((expert_id * N + col + problem_N) / scale_granularity_n) +
(k / scale_granularity_k) * scale_B_stride];
scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
(col + problem_N) / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
@@ -138,6 +141,14 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
@@ -159,6 +170,22 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
v_b_up = fp32_val_up.lo;
}
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
if constexpr(MoeGemmKind == 1)
{
const fp32x2_t fp32_val_up = pk_fp4_to_fp32x2(B[b_index_up / packed_size_b]);
if(k % 2 == 1)
v_b_up = fp32_val_up.hi;
else
v_b_up = fp32_val_up.lo;
}
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);