Files
composable_kernel/include/ck_tile/host/reference/reference_moe_gemm.hpp

317 lines
13 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
typename ActivationOp = identity>
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
const ck_tile::index_t* p_sorted_expert_ids_,
const ck_tile::index_t* p_max_token_id_,
const ADataType* A,
const BDataType* B,
CDataType* C,
const AccDataType* expert_weight_ptr,
ck_tile::index_t Num_tokens,
ck_tile::index_t TokensPerBlock,
ck_tile::index_t TopK,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t strideA,
ck_tile::index_t strideB,
ck_tile::index_t strideC,
index_t scale_granularity_m,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr,
float* expert_bias_ptr)
{
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int row = idx / problem_N; // Compute row index
int col = idx % problem_N; // Compute column index
index_t gather_token_id = 0;
index_t scatter_token_id = 0;
index_t expert_id = 0;
if(row < p_max_token_id_[0])
{
expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
gather_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
if(gather_token_id >= Num_tokens)
{
return;
}
if(MoeGemmKind == 2)
{
gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
}
else
{
scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
}
}
else
{
return;
}
if(row < M)
{
AccDataType acc = 0.0;
AccDataType acc_up = 0.0;
AccDataType acc_temp = 0.0;
AccDataType acc_up_temp = 0.0;
float scale_A = 0;
float scale_B = 0;
float scale_B_up = 0;
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
for(int k = 0; k < K; ++k)
{
if(k % scale_granularity_k == 0)
{
// update acc
acc += acc_temp * scale_A * scale_B;
acc_up += acc_up_temp * scale_A * scale_B_up;
// reset acc temp
acc_temp = 0.0;
acc_up_temp = 0.0;
// update scale factors
scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
(k / scale_granularity_k) * scale_A_stride];
scale_B =
scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
if constexpr(MoeGemmKind == 1)
scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
(col + problem_N) / scale_granularity_n +
(k / scale_granularity_k) * scale_B_stride];
}
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
// Adjust indexing based on matrix layout
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
? gather_token_id * strideA + k
: k * strideA + gather_token_id;
long b_index =
long(expert_id) * N * K +
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
: k * strideB + col);
long b_index_up;
if constexpr(MoeGemmKind == 1)
b_index_up = long(expert_id) * N * K +
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
? (col + problem_N) * strideB + k
: k * strideB + col + problem_N);
AccDataType v_a;
AccDataType v_b;
AccDataType v_b_up;
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
if(k % 2 == 1)
v_a = fp32_val.hi;
else
v_a = fp32_val.lo;
}
else
{
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
}
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
{
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
if constexpr(MoeGemmKind == 1)
{
const fp32x2_t fp32_val_up =
pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
if(k % 2 == 1)
v_b_up = fp32_val_up.hi;
else
v_b_up = fp32_val_up.lo;
}
}
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
{
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
if(k % 2 == 1)
v_b = fp32_val.hi;
else
v_b = fp32_val.lo;
if constexpr(MoeGemmKind == 1)
{
const fp32x2_t fp32_val_up =
pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
if(k % 2 == 1)
v_b_up = fp32_val_up.hi;
else
v_b_up = fp32_val_up.lo;
}
}
else
{
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
if constexpr(MoeGemmKind == 1)
v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
}
acc_temp += v_a * v_b;
if constexpr(MoeGemmKind == 1)
acc_up_temp += v_a * v_b_up;
}
acc += acc_temp * scale_A * scale_B;
acc_up += acc_up_temp * scale_A * scale_B_up;
float bias = 0.f, bias_up = 0.f;
if(expert_bias_ptr != nullptr)
{
bias = expert_bias_ptr[expert_id * N + col];
if constexpr(MoeGemmKind == 1)
bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
}
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
? scatter_token_id * strideC + col
: col * strideC + scatter_token_id;
if constexpr(MoeGemmKind < 2)
{
C[c_index] = ck_tile::type_convert<CDataType>(
ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
}
else
{
// moe gemm2 don't use activation.
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
ck_tile::fp16x2_t,
ck_tile::bf16x2_t>;
ResV2Type add_v{0, 0};
if(c_index % 2)
{
// result is the second value of fp16 pair.
add_v.y = res;
}
else
{
// result is the first value of fp16 pair.
add_v.x = res;
}
// mask last bit to make sure atomicAdd pointer is aligned of DWORD.
atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
}
}
}
template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
typename ActivationOp = identity>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
const index_t* p_sorted_expert_ids_,
const index_t* p_max_token_id_,
const ADataType* a_ptr,
const BDataType* b_ptr,
CDataType* c_ptr,
const AccDataType* expert_weight_ptr,
index_t Num_tokens,
index_t TokensPerBlock,
index_t TopK,
index_t M,
index_t N,
index_t K,
index_t stride_a,
index_t stride_b,
index_t stride_c,
index_t scale_granularity_m,
index_t scale_granularity_n,
index_t scale_granularity_k,
float* scale_A_ptr,
float* scale_B_ptr,
float* exp_bias = nullptr)
{
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
int totalElements = M * problem_N;
int numThreadsPerBlock = 256; // Common choice for threads per block
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
moe_gemm_kernel<ADataType,
BDataType,
AccDataType,
CDataType,
LayoutA,
LayoutB,
LayoutC,
MoeGemmKind,
ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
p_sorted_expert_ids_,
p_max_token_id_,
a_ptr,
b_ptr,
c_ptr,
expert_weight_ptr,
Num_tokens,
TokensPerBlock,
TopK,
M,
N,
K,
stride_a,
stride_b,
stride_c,
scale_granularity_m,
scale_granularity_n,
scale_granularity_k,
scale_A_ptr,
scale_B_ptr,
exp_bias);
return;
}
} // namespace ck_tile