[CK_TILE] support split-k a16w4 gemm1 (#3389)

* initial version to support moe gemm1 split-k

* add missing args

* fix build warning

* update reference

* for split-k disable bias and weight

* remove debug log

* fix format

* fix div by zero errors

* fix cmake config

* update

* resolve conflicts

* remove useless changes

* reformat

* fix

* remove useless changes

* fix ci

---------

Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>
This commit is contained in:
yadaish
2025-12-29 23:05:35 +08:00
committed by GitHub
parent a0acc83a72
commit dae85ead64
11 changed files with 136 additions and 78 deletions

View File

@@ -18,7 +18,7 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
const ck_tile::index_t* p_sorted_expert_ids_,
@@ -43,10 +43,11 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
float* scale_B_ptr,
float* expert_bias_ptr)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int row = idx / problem_N; // Compute row index
int col = idx % problem_N; // Compute column index
constexpr auto is_split_k = MoeGemmKind == 3;
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int row = idx / problem_N; // Compute row index
int col = idx % problem_N; // Compute column index
index_t gather_token_id = 0;
index_t scatter_token_id = 0;
@@ -203,7 +204,7 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
acc_up += acc_up_temp * scale_A * scale_B_up;
float bias = 0.f, bias_up = 0.f;
if(expert_bias_ptr != nullptr)
if(expert_bias_ptr != nullptr && !is_split_k)
{
bias = expert_bias_ptr[expert_id * N + col];
if constexpr(MoeGemmKind == 1)
@@ -221,23 +222,24 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
else
{
// moe gemm2 don't use activation.
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
ck_tile::fp16x2_t,
ck_tile::bf16x2_t>;
ResV2Type add_v{0, 0};
auto weight =
is_split_k ? ck_tile::type_convert<AccDataType>(1.0f) : expert_weight_ptr[row];
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * weight);
thread_buffer<CDataType, 2> add_v = 0;
if(c_index % 2)
{
// result is the second value of fp16 pair.
add_v.y = res;
add_v.template get_as<CDataType>()[1] = res;
}
else
{
// result is the first value of fp16 pair.
add_v.x = res;
add_v.template get_as<CDataType>()[0] = res;
}
// mask last bit to make sure atomicAdd pointer is aligned of DWORD.
atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
atomic_add_g<CDataType, 2>(reinterpret_cast<CDataType*>(C + (c_index & 0xffff'fffe)),
add_v);
}
}
}
@@ -249,7 +251,7 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
const index_t* p_sorted_expert_ids_,