mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Update flatmm related kernels (#3022)
--------- Co-authored-by: Ding, Yi <yi.ding@amd.com> Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
@@ -480,6 +480,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
@@ -492,6 +500,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
@@ -506,6 +522,121 @@ __global__ void naive_gemm_kernel(ADataType* A,
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
__global__ void blockwise_gemm_kernel(ADataType* A,
|
||||
BDataType* B,
|
||||
CDataType* C,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t strideA,
|
||||
ck_tile::index_t strideB,
|
||||
ck_tile::index_t strideC,
|
||||
ck_tile::index_t scale_granularity_m,
|
||||
ck_tile::index_t scale_granularity_n,
|
||||
ck_tile::index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row = idx / N; // Compute row index
|
||||
int col = idx % N; // Compute column index
|
||||
|
||||
if(row < M && col < N)
|
||||
{
|
||||
AccDataType acc = 0.0, acc_temp = 0.0;
|
||||
|
||||
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
|
||||
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
|
||||
|
||||
float scale_A = 0;
|
||||
float scale_B = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if(k % scale_granularity_k == 0)
|
||||
{
|
||||
// update acc
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
acc_temp = 0.0;
|
||||
// update scale factors
|
||||
scale_A = scale_A_ptr[(row / scale_granularity_m) +
|
||||
(k / scale_granularity_k) * scale_A_stride];
|
||||
scale_B = scale_B_ptr[(col / scale_granularity_n) +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
}
|
||||
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideA + k
|
||||
: k * strideA + row;
|
||||
int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? col * strideB + k
|
||||
: k * strideB + col;
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
}
|
||||
acc_temp += v_a * v_b;
|
||||
}
|
||||
// final accumulation
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? row * strideC + col
|
||||
: col * strideC + row;
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(acc);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -534,6 +665,51 @@ void reference_gemm_gpu(ADataType* a_ptr,
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC>
|
||||
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
|
||||
BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c,
|
||||
index_t scale_granularity_m,
|
||||
index_t scale_granularity_n,
|
||||
index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr)
|
||||
{
|
||||
int totalElements = M * N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
|
||||
<<<numBlocks, numThreadsPerBlock>>>(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
scale_granularity_m,
|
||||
scale_granularity_n,
|
||||
scale_granularity_k,
|
||||
scale_A_ptr,
|
||||
scale_B_ptr);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -571,4 +747,5 @@ void reference_batched_gemm_gpu(ADataType* a_ptr,
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
316
include/ck_tile/host/reference/reference_moe_gemm.hpp
Normal file
316
include/ck_tile/host/reference/reference_moe_gemm.hpp
Normal file
@@ -0,0 +1,316 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
|
||||
typename ActivationOp = identity>
|
||||
__global__ void moe_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
const ck_tile::index_t* p_sorted_expert_ids_,
|
||||
const ck_tile::index_t* p_max_token_id_,
|
||||
const ADataType* A,
|
||||
const BDataType* B,
|
||||
CDataType* C,
|
||||
const AccDataType* expert_weight_ptr,
|
||||
ck_tile::index_t Num_tokens,
|
||||
ck_tile::index_t TokensPerBlock,
|
||||
ck_tile::index_t TopK,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t strideA,
|
||||
ck_tile::index_t strideB,
|
||||
ck_tile::index_t strideC,
|
||||
index_t scale_granularity_m,
|
||||
index_t scale_granularity_n,
|
||||
index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr,
|
||||
float* expert_bias_ptr)
|
||||
{
|
||||
int idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int row = idx / problem_N; // Compute row index
|
||||
int col = idx % problem_N; // Compute column index
|
||||
|
||||
index_t gather_token_id = 0;
|
||||
index_t scatter_token_id = 0;
|
||||
index_t expert_id = 0;
|
||||
|
||||
if(row < p_max_token_id_[0])
|
||||
{
|
||||
expert_id = p_sorted_expert_ids_[row / TokensPerBlock];
|
||||
gather_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
|
||||
scatter_token_id = p_sorted_token_ids_[row] & 0xff'ffff;
|
||||
if(gather_token_id >= Num_tokens)
|
||||
{
|
||||
return;
|
||||
}
|
||||
if(MoeGemmKind == 2)
|
||||
{
|
||||
gather_token_id = gather_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
|
||||
}
|
||||
else
|
||||
{
|
||||
scatter_token_id = scatter_token_id * TopK + (p_sorted_token_ids_[row] >> 24);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if(row < M)
|
||||
{
|
||||
AccDataType acc = 0.0;
|
||||
AccDataType acc_up = 0.0;
|
||||
|
||||
AccDataType acc_temp = 0.0;
|
||||
AccDataType acc_up_temp = 0.0;
|
||||
|
||||
float scale_A = 0;
|
||||
float scale_B = 0;
|
||||
float scale_B_up = 0;
|
||||
|
||||
index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
|
||||
index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;
|
||||
index_t scale_B_expert_stride = scale_B_stride * K / scale_granularity_k;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if(k % scale_granularity_k == 0)
|
||||
{
|
||||
// update acc
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
acc_up += acc_up_temp * scale_A * scale_B_up;
|
||||
// reset acc temp
|
||||
acc_temp = 0.0;
|
||||
acc_up_temp = 0.0;
|
||||
// update scale factors
|
||||
scale_A = scale_A_ptr[(gather_token_id / scale_granularity_m) +
|
||||
(k / scale_granularity_k) * scale_A_stride];
|
||||
scale_B =
|
||||
scale_B_ptr[expert_id * scale_B_expert_stride + col / scale_granularity_n +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
scale_B_up = scale_B_ptr[expert_id * scale_B_expert_stride +
|
||||
(col + problem_N) / scale_granularity_n +
|
||||
(k / scale_granularity_k) * scale_B_stride];
|
||||
}
|
||||
|
||||
constexpr index_t packed_size_a = ck_tile::numeric_traits<ADataType>::PackedSize;
|
||||
constexpr index_t packed_size_b = ck_tile::numeric_traits<BDataType>::PackedSize;
|
||||
// Adjust indexing based on matrix layout
|
||||
int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
|
||||
? gather_token_id * strideA + k
|
||||
: k * strideA + gather_token_id;
|
||||
|
||||
long b_index =
|
||||
long(expert_id) * N * K +
|
||||
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>) ? col * strideB + k
|
||||
: k * strideB + col);
|
||||
long b_index_up;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
b_index_up = long(expert_id) * N * K +
|
||||
((std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
|
||||
? (col + problem_N) * strideB + k
|
||||
: k * strideB + col + problem_N);
|
||||
|
||||
AccDataType v_a;
|
||||
AccDataType v_b;
|
||||
AccDataType v_b_up;
|
||||
if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
|
||||
if(k % 2 == 1)
|
||||
v_a = fp32_val.hi;
|
||||
else
|
||||
v_a = fp32_val.lo;
|
||||
}
|
||||
else
|
||||
{
|
||||
v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
|
||||
}
|
||||
if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
{
|
||||
const fp32x2_t fp32_val_up =
|
||||
pk_int4_t_to_fp32x2_t(B[b_index_up / packed_size_b]);
|
||||
if(k % 2 == 1)
|
||||
v_b_up = fp32_val_up.hi;
|
||||
else
|
||||
v_b_up = fp32_val_up.lo;
|
||||
}
|
||||
}
|
||||
else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
|
||||
{
|
||||
const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_b = fp32_val.hi;
|
||||
else
|
||||
v_b = fp32_val.lo;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
{
|
||||
const fp32x2_t fp32_val_up =
|
||||
pk_fp4_to_fp32x2(B[b_index_up / packed_size_b], 1.0f);
|
||||
if(k % 2 == 1)
|
||||
v_b_up = fp32_val_up.hi;
|
||||
else
|
||||
v_b_up = fp32_val_up.lo;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
v_b_up = ck_tile::type_convert<AccDataType>(B[b_index_up]);
|
||||
}
|
||||
acc_temp += v_a * v_b;
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
acc_up_temp += v_a * v_b_up;
|
||||
}
|
||||
|
||||
acc += acc_temp * scale_A * scale_B;
|
||||
acc_up += acc_up_temp * scale_A * scale_B_up;
|
||||
|
||||
float bias = 0.f, bias_up = 0.f;
|
||||
if(expert_bias_ptr != nullptr)
|
||||
{
|
||||
bias = expert_bias_ptr[expert_id * N + col];
|
||||
if constexpr(MoeGemmKind == 1)
|
||||
bias_up = expert_bias_ptr[expert_id * N + col + problem_N];
|
||||
}
|
||||
|
||||
int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
|
||||
? scatter_token_id * strideC + col
|
||||
: col * strideC + scatter_token_id;
|
||||
if constexpr(MoeGemmKind < 2)
|
||||
{
|
||||
C[c_index] = ck_tile::type_convert<CDataType>(
|
||||
ActivationOp{}(acc + bias, MoeGemmKind == 1 ? acc_up + bias_up : 1));
|
||||
}
|
||||
else
|
||||
{
|
||||
// moe gemm2 don't use activation.
|
||||
CDataType res = ck_tile::type_convert<CDataType>((acc + bias) * expert_weight_ptr[row]);
|
||||
using ResV2Type = std::conditional_t<std::is_same_v<CDataType, ck_tile::half_t>,
|
||||
ck_tile::fp16x2_t,
|
||||
ck_tile::bf16x2_t>;
|
||||
ResV2Type add_v{0, 0};
|
||||
if(c_index % 2)
|
||||
{
|
||||
// result is the second value of fp16 pair.
|
||||
add_v.y = res;
|
||||
}
|
||||
else
|
||||
{
|
||||
// result is the first value of fp16 pair.
|
||||
add_v.x = res;
|
||||
}
|
||||
// mask last bit to make sure atomicAdd pointer is aligned of DWORD.
|
||||
atomic_add<ResV2Type>(reinterpret_cast<ResV2Type*>(C + (c_index & 0xffff'fffe)), add_v);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
int MoeGemmKind = 0, // 0: gemm1_gate_only, 1: gemm1_gate_up, 2: gemm2
|
||||
typename ActivationOp = identity>
|
||||
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
|
||||
const index_t* p_sorted_expert_ids_,
|
||||
const index_t* p_max_token_id_,
|
||||
const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
CDataType* c_ptr,
|
||||
const AccDataType* expert_weight_ptr,
|
||||
index_t Num_tokens,
|
||||
index_t TokensPerBlock,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t stride_a,
|
||||
index_t stride_b,
|
||||
index_t stride_c,
|
||||
index_t scale_granularity_m,
|
||||
index_t scale_granularity_n,
|
||||
index_t scale_granularity_k,
|
||||
float* scale_A_ptr,
|
||||
float* scale_B_ptr,
|
||||
float* exp_bias = nullptr)
|
||||
{
|
||||
int problem_N = MoeGemmKind == 1 ? N / 2 : N;
|
||||
int totalElements = M * problem_N;
|
||||
int numThreadsPerBlock = 256; // Common choice for threads per block
|
||||
int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
|
||||
|
||||
moe_gemm_kernel<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
LayoutA,
|
||||
LayoutB,
|
||||
LayoutC,
|
||||
MoeGemmKind,
|
||||
ActivationOp><<<numBlocks, numThreadsPerBlock>>>(p_sorted_token_ids_,
|
||||
p_sorted_expert_ids_,
|
||||
p_max_token_id_,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
expert_weight_ptr,
|
||||
Num_tokens,
|
||||
TokensPerBlock,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_a,
|
||||
stride_b,
|
||||
stride_c,
|
||||
scale_granularity_m,
|
||||
scale_granularity_n,
|
||||
scale_granularity_k,
|
||||
scale_A_ptr,
|
||||
scale_B_ptr,
|
||||
exp_bias);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user