mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] support split-k a16w4 gemm1 (#3389)
* initial version to support moe gemm1 split-k * add missing args * fix build warning * update reference * for split-k disable bias and weight * remove debug log * fix format * fix div by zero errors * fix cmake config * update * resolve conflicts * remove useless changes * reformat * fix * remove useless changes * fix ci --------- Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: root <root@smci355-ccs-aus-m01-25.cs-aus.dcgpu>
This commit is contained in:
@@ -18,7 +18,7 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
|
||||
typename ActivationOp = identity>
|
||||
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
const ck_tile::index_t* p_sorted_expert_ids_,
|
||||
@@ -43,10 +43,11 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
float* scale_B_ptr,
|
||||
float* expert_bias_ptr)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int row = idx / problem_N; // Compute row index
|
||||
int col = idx % problem_N; // Compute column index
|
||||
constexpr auto is_split_k = MoeGemmKind == 3;
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int row = idx / problem_N; // Compute row index
|
||||
int col = idx % problem_N; // Compute column index
|
||||
|
||||
index_t gather_token_id = 0;
|
||||
index_t scatter_token_id = 0;
|
||||
@@ -203,7 +204,7 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
acc_up += acc_up_temp * scale_A * scale_B_up;
|
||||
|
||||
float bias = 0.f, bias_up = 0.f;
|
||||
if(expert_bias_ptr != nullptr)
|
||||
if(expert_bias_ptr != nullptr && !is_split_k)
|
||||
{
|
||||
bias = expert_bias_ptr[expert_id * N + col];
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
@@ -221,23 +222,24 @@ __global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
else
|
||||
{
|
||||
// moe gemm2 don't use activation.
|
||||
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
|
||||
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
|
||||
ck_tile::fp16x2_t,
|
||||
ck_tile::bf16x2_t>;
|
||||
ResV2Type add_v{0, 0};
|
||||
auto weight =
|
||||
is_split_k ? ck_tile::type_convert<AccDataType>(1.0f) : expert_weight_ptr[row];
|
||||
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * weight);
|
||||
|
||||
thread_buffer<CDataType, 2> add_v = 0;
|
||||
if(c_index % 2)
|
||||
{
|
||||
// result is the second value of fp16 pair.
|
||||
add_v.y = res;
|
||||
add_v.template get_as<CDataType>()[1] = res;
|
||||
}
|
||||
else
|
||||
{
|
||||
// result is the first value of fp16 pair.
|
||||
add_v.x = res;
|
||||
add_v.template get_as<CDataType>()[0] = res;
|
||||
}
|
||||
// mask last bit to make sure atomicAdd pointer is aligned of DWORD.
|
||||
atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
|
||||
atomic_add_g<CDataType, 2>(reinterpret_cast<CDataType*>(C + (c_index & 0xffff'fffe)),
|
||||
add_v);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -249,7 +251,7 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2, 3:gemm1_split_k
|
||||
typename ActivationOp = identity>
|
||||
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
|
||||
const index_t* p_sorted_expert_ids_,
|
||||
|
||||
Reference in New Issue
Block a user