Merge branch 'moe_bs_stage1_dev' into moe_merge_v3_bs_for_aiter

This commit is contained in:
OscarXu
2025-05-17 00:43:46 -05:00
13 changed files with 3824 additions and 460 deletions

View File

@@ -41,6 +41,7 @@ using DsDataType = ck::Tuple<>;
using EDataType = BF16;
using A0Layout = Row;
using A1Layout = Col;
using B0Layout = Col;
using D0Layout = Row;
using D1Layout = Col;
@@ -158,7 +159,8 @@ int main(int argc, char* argv[])
exit(0);
}
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
// Transpose the AScale tensor for better performance
ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
auto f_host_tensor_descriptor =
@@ -178,8 +180,8 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
Tensor<A1DataType> a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M,
(K + Scale_Block_K - 1) / Scale_Block_K,
Scale_Stride_AM,
A0Layout{}));
Scale_Stride_AK,
A1Layout{}));
Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
Tensor<B0DataType> b0_preshuffled(
f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use laout only for size
@@ -196,7 +198,6 @@ int main(int argc, char* argv[])
std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl;
std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl;
#if 1
switch(init_method)
{
case 0: break;
@@ -236,17 +237,6 @@ int main(int argc, char* argv[])
a1_m_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0, 1.0});
b1_k_n.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
}
#endif
#if 0
for(int im =0; im< (M + Scale_Block_M - 1) / Scale_Block_M; im++){
float row_sum = .0;
for(int ik =0; ik< (K + Scale_Block_K - 1) / Scale_Block_K; ik++){
printf("%lf ",a1_m_k(im, ik));
row_sum += a1_m_k(im, ik);
}
printf("sum: %lf\n", row_sum * 128);
}
#endif
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize());

View File

@@ -0,0 +1,547 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using F8 = ck::f8_t;
using F32 = float;
using I64 = int64_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F8;
using A1DataType = F32;
using B0DataType = F8;
using B1DataType = F32;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = EDataType;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D2Layout>;
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D2>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
{
// for real kernel use
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void operator()<EDataType, EDataType, float>(
EDataType& e, const EDataType& c, const float& d2) const
{
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void
operator()<float, float, float>(float& e, const float& c, const float& d2) const
{
// for reference cpu
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
using CDEElementOp = MulABScaleExpertWeight;
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
int KPack = 16 / sizeof(B0DataType);
int NLane = NXdl;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * static_cast<I64>(K) + k];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false;
#if 0
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
// clang-format off
< Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
//threadnum, mblock, nblock, kblock
BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, NPerBlock, KPerBlock,
// ak1, bk1
AK1, BK1,
// mn_perxdl
MNPerXDL, MNPerXDL,
// mn_xdlperwave
MXDLPerWave, NXDLPerWave,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, 128, 128,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
#endif
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
#if 1
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 6144;
ck::index_t experts = 8;
ck::index_t topk = 2;
// ck::index_t sorted_tile_num = 133;
// ck::index_t valid_tile_num = 128;
// ck::index_t tokens = 8192;
// ck::index_t sorted_tile_num = 15;
// ck::index_t valid_tile_num = 13;
ck::index_t sorted_tile_num = 55;
ck::index_t valid_tile_num = 52;
ck::index_t tokens = 832;
#else
//deepseek
ck::index_t N = 2048;
ck::index_t K = 7168;
ck::index_t experts = 256;
ck::index_t topk = 8;
ck::index_t tokens = 4096;
ck::index_t sorted_tile_num = 261;
ck::index_t valid_tile_num = 256;
#endif
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
sorted_tile_num = std::stoi(argv[7]);
valid_tile_num = std::stoi(argv[8]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData = {valid_size};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<A1DataType> a1_t_k(HostTensorDescriptor({tokens, (K + Scale_Block_K - 1) / Scale_Block_K},
{Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<B1DataType> b1_e_n_k(HostTensorDescriptor(
{experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N * 2},
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
// a0_t_k.savetxt("a.txt");
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
// d2_e_n.savetxt("d2_e_n.txt", "int");
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
a1_device_buf.ToDevice(a1_t_k.mData.data());
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
a1_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
sizeof(B0DataType) * K * N * 2 * experts +
sizeof(EDataType) * valid_tile_num * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s.\n" << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> a_t_k({tokens, K});
Tensor<float> b_e_n_k({experts, K, N * 2});
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<CShuffleDataType> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
//handle scale before ref.
for(int t = 0; t < tokens; ++t)
{
for(int k = 0; k < K; ++k)
{
a_t_k(t, k) = ck::type_convert<float>(a0_t_k(t, k)) *
a1_t_k(t, k / Scale_Block_K);
}
}
for(int e = 0; e < experts; ++e)
{
for(int k = 0; k < K; ++k)
{
for(int n = 0; n < N * 2; ++n)
{
b_e_n_k(e, k, n) = ck::type_convert<float>(b0_e_n_k(e, k, n)) *
b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N);
}
}
}
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm1BlockScale<float,
float,
CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough,
ActOP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a_t_k,
b_e_n_k,
d2_e_n,
c_t_k_n,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < valid_size; ++m)
{
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
if(t >= tokens)
{
continue;
}
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, topk_id, n) = ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
auto status = ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
if (status == 0){
printf("Validation Pass.\n");
}
return status;
}
return 0;
}

View File

@@ -39,7 +39,7 @@ using B0DataType = F8;
using B1DataType = F32;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F32;
using CShuffleDataType = EDataType;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
@@ -63,9 +63,16 @@ struct MulABScaleExpertWeight
operator()<EDataType, float, float>(EDataType& e, const float& c, const float& d2) const
{
// for real kernel use
e = ck::type_convert<EDataType>(c * d2);
(void)d2;
e = ck::type_convert<EDataType>(c);
}
template <>
__host__ __device__ constexpr void operator()<EDataType, EDataType, float>(
EDataType& e, const EDataType& c, const float& d2) const
{
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void
@@ -121,13 +128,13 @@ static constexpr ck::index_t Scale_Block_K = 128;
#if 0
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t MXDLPerWave = 1;
static constexpr ck::index_t NXDLPerWave = 1;
static constexpr ck::index_t MXDLPerWave = 2;
static constexpr ck::index_t NXDLPerWave = 2;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 32;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType);
static constexpr ck::index_t CShuffleNLane = 32;
static constexpr ck::index_t CShuffleNLane = 16;
static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane;
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
@@ -147,10 +154,10 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
AK1, BK1,
MNPerXDL, MNPerXDL,
MXDLPerWave, NXDLPerWave,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
1, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>;
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
@@ -164,8 +171,8 @@ static constexpr ck::index_t MPerBlock = 128; using DeviceOpInstance = ck::tenso
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 1, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, false, false, A0DataType>;
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, int32_t, A0DataType>;
#endif
// clang-format on
@@ -493,8 +500,8 @@ int main(int argc, char* argv[])
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm2BlockScale<float,
float,
D2DataType,
CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,

View File

@@ -194,183 +194,6 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
__device__ static constexpr auto HotLoopScheduler()
{
#if 0
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma = num_mfma / MRepeat;
constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
if constexpr(stage.value == 0)
{
// B VMEM access.
constexpr auto staged_num_buffer_load_b_per_ds_read_a =
num_buffer_load_inst_b / staged_num_ds_read_inst_a;
constexpr auto staged_num_mfma_per_buffer_load_b =
staged_num_mfma / num_buffer_load_inst_b;
// B global
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
ignore = ibuf_inst;
static_for<0, staged_num_mfma_per_buffer_load_b, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
static_for<0, staged_num_mfma_per_buffer_load_b - 1, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
{
// A LDS write access.
constexpr auto staged_num_mfma_per_ds_write_a =
math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
// A local write
static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
static_for<0, staged_num_mfma_per_ds_write_a - 1, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
static_for<0, staged_num_mfma_per_ds_write_a, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
static_for<0, staged_num_mfma_per_ds_write_a - 2, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
static_for<0, staged_num_mfma_per_ds_write_a - 1, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
}
}
});
__builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 2)
{
// A VMEM access.
constexpr auto staged_num_mfma_per_buffer_load_a =
math::integer_divide_ceil(staged_num_mfma, num_buffer_load_inst_a);
constexpr auto stage_more_mfma =
staged_num_mfma - (staged_num_mfma_per_buffer_load_a - 1) * num_buffer_load_inst_a;
// A global
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i_inst) {
if constexpr(i_inst.value < stage_more_mfma)
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
static_for<0, staged_num_mfma_per_buffer_load_a - 1, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
static_for<0, staged_num_mfma_per_buffer_load_a, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
}
else
{
if(i_inst.value < staged_num_ds_read_inst_a)
{
static_for<0, staged_num_mfma_per_buffer_load_a - 2, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
}
else
{
static_for<0, staged_num_mfma_per_buffer_load_a - 1, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
}
});
__builtin_amdgcn_sched_barrier(0);
}
else
{
// A local Read
static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
ignore = i_inst;
static_for<0, staged_num_mfma_per_ds_read_a, 1>{}([&](auto i_mfma) {
ignore = i_mfma;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
__builtin_amdgcn_sched_barrier(0);
}
#elif 1
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
@@ -452,11 +275,22 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) &&
(imfma < (num_mfma_perstage - 1)))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
// __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac
});
// Scale load, 1B
if constexpr(i.value == 0)
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// Scale load, 1A
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_barrier(0);
});
// A global read + A local write
@@ -481,11 +315,17 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
{
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) &&
(imfma < (num_mfma_perstage - 1)))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
// __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac
});
// Scale load, 1A
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_barrier(0);
});
// lds synchronization, prefetch next loop local A
@@ -493,13 +333,19 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
ignore = i;
static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
if constexpr((imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage - 1)) &&
(imfma < (num_mfma_perstage - 1)))
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x800, 2, 0); // v_pk_fma
// __builtin_amdgcn_sched_group_barrier(0x1000, 4, 0); // v_fmac
});
// Scale load, 1A
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
// __builtin_amdgcn_sched_barrier(0);
});
#endif
}
template <bool HasMainLoop,
@@ -577,13 +423,18 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
b_scale_thread_desc.GetElementSpaceSize());
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
c_scale_thread_desc.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
// StaticallyIndexedArray<decltype(c_scale_thread_buf), Number<2>{}> c_scale_thread_bufs;
// Global prefetch A1 B1, AScale1 BScale1
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
@@ -601,7 +452,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_bufs(I0));
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
@@ -621,12 +472,12 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_bufs(I0));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = __builtin_elementwise_fma(a_scale_thread_buf[m0], b_scale_thread_buf[I0], .0f);
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
});
// Local prefill A1
@@ -636,12 +487,13 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
#if 1
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_bufs(I0));
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
@@ -661,16 +513,19 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_bufs(I0));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
#endif
// Initialize C
c_thread_buf.Clear();
// Double register buffer for non-scaled gemm computation
// 1. Reduce register pressure
// 2. Decouple the dependency between mfma instruction and scale-fma instruction following.
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
1,
2,
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_per_scale;
@@ -681,7 +536,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, KGroup, 1>{}([&](auto kg0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(m0, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
@@ -690,6 +545,32 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
});
});
#if 1
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
// Fill first mfma buffer
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(I0, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_bufs
[I0][Number<b_thread_desc_.CalculateOffset(make_tuple(I0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
#endif
__builtin_amdgcn_sched_barrier(0);
// main body
@@ -710,6 +591,36 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_bufs(local_read_buf));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
});
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs(local_read_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
@@ -718,10 +629,25 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto a_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
constexpr auto b_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
constexpr auto b_local_buf_id =
Number<mfma_reg_buf ^
((m0 * NRepeat + n0 + 1) / (MRepeat * NRepeat))>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
c_thread_buf_per_scale
.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
@@ -729,7 +655,8 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
make_tuple((a_local_buf_offset +
HotloopLocalBufSwitch * mfma_reg_buf) %
2,
I0,
I0,
@@ -737,9 +664,9 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_bufs
[b_local_buf_id][Number<b_thread_desc_.CalculateOffset(
make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
});
using mfma_input_type =
@@ -749,7 +676,8 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
c_thread_buf_per_scale.GetVectorTypeReference(
Number<mfma_buf_offset>{}));
});
constexpr index_t c_offset =
@@ -760,7 +688,8 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
@@ -779,7 +708,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
@@ -805,7 +734,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(local_read_buf),
@@ -831,7 +760,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(mfma_reg_buf),
@@ -849,45 +778,14 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
});
}
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] *
b_scale_thread_bufs[mfma_reg_buf][I0];
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = __builtin_elementwise_fma(a_scale_thread_buf[m0], b_scale_thread_buf[I0], .0f);
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
});
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step);
// __builtin_amdgcn_sched_group_barrier(0x020, MRepeat + 1, 0); // VMEM read
__builtin_amdgcn_sched_barrier(0);
};
LoopFunc(I0, I1);
@@ -915,8 +813,21 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto a_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
constexpr auto b_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
constexpr auto b_local_buf_id =
Number<0 ^ ((m0 * NRepeat + n0 + 1) / (MRepeat * NRepeat))>{};
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
@@ -926,19 +837,19 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
make_tuple(a_local_buf_offset % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_bufs[b_local_buf_id][Number<b_thread_desc_.CalculateOffset(
make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(
Number<mfma_buf_offset>{}));
});
constexpr index_t c_offset =
@@ -949,7 +860,8 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
@@ -968,7 +880,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
@@ -988,7 +900,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I1),
@@ -1008,7 +920,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
make_tuple(Number<(m0 + 2) % MRepeat>{},
I0,
I0,
Number<k0 * 2 + kg0>{},
Number<k0 * KGroup + kg0>{},
I0,
I0),
a_block_buf.At(I0),
@@ -1024,7 +936,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
HotLoopScheduler();
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = __builtin_elementwise_fma(a_scale_thread_buf[m0], b_scale_thread_buf[I0], .0f);
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
@@ -1035,31 +947,52 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
(m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
constexpr auto a_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
constexpr auto b_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1))))
{
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple((a_local_buf_offset + HotloopLocalBufSwitch) % 2,
I0,
I0,
k0,
I0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(
Number<mfma_buf_offset>{}));
});
}
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -1068,7 +1001,8 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
@@ -1083,7 +1017,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I1),
a_thread_desc_,
make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
@@ -1111,31 +1045,47 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
constexpr auto a_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
constexpr auto b_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1))))
{
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(a_local_buf_offset % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(
Number<mfma_buf_offset>{}));
});
}
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -1144,7 +1094,8 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
@@ -1159,7 +1110,7 @@ struct BlockwiseGemmXdlops_pipeline_blockscale_bpreshuffle_v3<BlockGemmPipelineS
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k0_k1_k2,
make_tuple(
Number<m0 + 2>{}, I0, I0, Number<k0 * 2 + kg0>{}, I0, I0),
Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
a_block_buf.At(I0),
a_thread_desc_,
make_tuple(

View File

@@ -3,9 +3,11 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp"
// #include
// "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_blockscale_b_preshuffle_v2.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_gufusion_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v3.hpp"
namespace ck {
@@ -32,12 +34,15 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
index_t KPack,
bool GUFusion = false>
constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v1<
BlkGemmPipeSche,
BlockSize,
ADataType,
@@ -61,6 +66,34 @@ constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
MRepeat,
NRepeat,
KPack>{};
}
else
{
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
#if 0
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
@@ -91,30 +124,60 @@ constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
static_assert(MRepeat >= 4, "MRepeat should at least be 4 in BlockGemmPipelineVersion::v3");
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_gufusion_v3<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
return BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v3<
BlkGemmPipeSche,
BlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MScaleBlock,
NScaleBlock,
KScaleBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
else
{

View File

@@ -71,8 +71,11 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t ActivationOP = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = false,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
@@ -146,7 +149,11 @@ struct DeviceMoeGemmBlockScale
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ActivationOP,
NSwizzle,
IsInputGemm,
MulRoutedWeight,
IndexType,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
@@ -154,6 +161,20 @@ struct DeviceMoeGemmBlockScale
using Argument = typename GridwiseGemm::Argument;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
return 2;
else
return 1;
}();
int GetPreShuffleParameters() override { return NPerXDL; }
// Invoker
@@ -349,10 +370,10 @@ struct DeviceMoeGemmBlockScale
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / BPackedSize;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
@@ -412,8 +433,7 @@ struct DeviceMoeGemmBlockScale
constexpr index_t minimum_occupancy = 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
#if CK_USE_ASM_MOE_STAGE2_BLOCKSCALE
(void)minimum_occupancy;
@@ -486,7 +506,6 @@ struct DeviceMoeGemmBlockScale
true,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -496,7 +515,6 @@ struct DeviceMoeGemmBlockScale
true,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -511,7 +529,6 @@ struct DeviceMoeGemmBlockScale
true,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -521,7 +538,6 @@ struct DeviceMoeGemmBlockScale
true,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -543,7 +559,6 @@ struct DeviceMoeGemmBlockScale
false,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -553,7 +568,6 @@ struct DeviceMoeGemmBlockScale
false,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Even>;
RunKernel(kernel);
}
@@ -567,7 +581,6 @@ struct DeviceMoeGemmBlockScale
false,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Odd>;
RunKernel(kernel);
}
@@ -577,7 +590,6 @@ struct DeviceMoeGemmBlockScale
false,
MemoryDataOp,
minimum_occupancy,
IsInputGemm,
TailNumber::Even>;
RunKernel(kernel);
}

View File

@@ -1134,7 +1134,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
make_tuple(1, math::integer_divide_ceil(problem.M, ScaleBlockM)));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
@@ -1282,9 +1282,9 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
Sequence<1, 0>,
0,
1,
ScaleSliceSizeK,
1,
false>(
a_scale_grid_desc_am_ak,
@@ -1630,7 +1630,7 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.M, ScaleBlockM),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
make_tuple(1, math::integer_divide_ceil(problem.M, ScaleBlockM)));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
@@ -1784,9 +1784,9 @@ struct GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
Sequence<1, 0>,
0,
1,
ScaleSliceSizeK,
1,
false>(
a_scale_grid_desc_am_ak,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -26,12 +26,18 @@ namespace ck {
// two lds chunks.
// 2. Occupied __shared__ won't release until whole shader end, a.k.a AB and C may not use same lds
// buffer when we declare __shared__ inside blkgemmpipe
enum Activation
{
gelu_and_mul = 0,
silu_and_mul = 1
};
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
TailNumber TailNum = TailNumber::Full>
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
@@ -44,7 +50,7 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
@@ -68,7 +74,6 @@ template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
bool IsInputGemm = false,
TailNumber TailNum = TailNumber::Even>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
@@ -83,8 +88,7 @@ __global__ void
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
GridwiseGemm::
template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, IsInputGemm, TailNum>(
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_sorted_token_ids,
karg.p_sorted_expert_ids,
karg.p_max_token_id,
@@ -154,7 +158,11 @@ template <typename ALayout,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t ActivationOperation = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType,
@@ -309,7 +317,7 @@ struct GridwiseMoeGemmBlockScale
}
__host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
{
const auto a_grid_desc_mraw_kraw = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
@@ -501,8 +509,8 @@ struct GridwiseMoeGemmBlockScale
}
template <typename ELayout>
__host__ __device__ static auto
MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
__host__ __device__ static auto MakeCGridDescriptor_M_N(
IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
{
const auto c_grid_desc_mraw_nraw = [&]() {
if constexpr(is_same<tensor_layout::gemm::RowMajor, ELayout>::value)
@@ -925,7 +933,8 @@ struct GridwiseMoeGemmBlockScale
NPerXdl,
MXdlPerWave,
NXdlPerWave,
KPack>())>;
KPack,
IsInputGemm>())>;
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -1157,7 +1166,6 @@ struct GridwiseMoeGemmBlockScale
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
@@ -1198,7 +1206,7 @@ struct GridwiseMoeGemmBlockScale
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
make_tuple(math::integer_divide_ceil(problem.N , ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
@@ -1247,7 +1255,7 @@ struct GridwiseMoeGemmBlockScale
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets;
StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
@@ -1255,11 +1263,11 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride =
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) *
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
// N0, K0, Blocksize*KPack
@@ -1307,7 +1315,7 @@ struct GridwiseMoeGemmBlockScale
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
index_t,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1350,6 +1358,7 @@ struct GridwiseMoeGemmBlockScale
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
decltype(c_thread_buf) c_thread_buf_up;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
@@ -1431,38 +1440,115 @@ struct GridwiseMoeGemmBlockScale
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
if constexpr(IsInputGemm)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_blockwise_copy_up,
b_grid_buf,
b_grid_buf_up,
b_block_buf,
b_block_slice_copy_step,
c_scale_thread_desc,
c_thread_buf,
c_scale_thread_desc,
c_thread_buf,
c_thread_buf_up,
a_scale_grid_desc_am_ak,
a_scale_thread_desc,
a_scale_thread_copy,
a_scale_grid_buf,
a_scale_thread_slice_copy_step,
a_scale_grid_desc_am_ak,
a_scale_thread_desc,
a_scale_thread_copy,
a_scale_grid_buf,
a_scale_thread_slice_copy_step,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_grid_buf,
b_scale_thread_slice_copy_step,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_thread_copy_up,
b_scale_grid_buf,
b_scale_grid_buf_up,
b_scale_thread_slice_copy_step,
num_k_block_main_loop);
num_k_block_main_loop);
}
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_slice_copy_step,
c_scale_thread_desc,
c_thread_buf,
a_scale_grid_desc_am_ak,
a_scale_thread_desc,
a_scale_thread_copy,
a_scale_grid_buf,
a_scale_thread_slice_copy_step,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_grid_buf,
b_scale_thread_slice_copy_step,
num_k_block_main_loop);
}
// shuffle C and write out
{
@@ -1478,7 +1564,7 @@ struct GridwiseMoeGemmBlockScale
blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
// TODO: hacky, fix it!
// c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
// c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
@@ -1491,6 +1577,71 @@ struct GridwiseMoeGemmBlockScale
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
if constexpr(IsInputGemm) // gu fusion, elementwise
{
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
static_assert(N4 == 4);
const index_t n1 = get_warp_local_1d_id() / MWave;
const index_t n3 = threadIdx.x % get_warp_size() / NPerXdl;
vector_type<float, 4> topk_weights;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
const index_t n_pos = block_n_id * NPerBlock + n0 * N1 * N2 * N3 * N4 +
n1 * N2 * N3 * N4 + n2 * N3 * N4 + n3 * N4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, N4>*>(
p_ds_grid[I0] + n_pos);
}
// if((blockIdx.x == 0) && (blockIdx.y == 0)){printf("m0:%d, n_pos:%d\n", static_cast<int>(m0), n_pos);}
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[n4];
up = up * topk_weights.AsType<float>()[n4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[n4];
up = up * topk_weights.AsType<float>()[n4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
});
});
});
});
}
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -1656,8 +1807,8 @@ struct GridwiseMoeGemmBlockScale
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
index_t,
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
@@ -1701,7 +1852,7 @@ struct GridwiseMoeGemmBlockScale
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets;
StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
auto dstidx = sfc_cde_block.GetIndex(access_id);
@@ -1768,7 +1919,6 @@ struct GridwiseMoeGemmBlockScale
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
bool IsInputGemm = true,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
const index_t* p_sorted_expert_ids,
@@ -1858,7 +2008,7 @@ struct GridwiseMoeGemmBlockScale
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats>
StaticallyIndexedArray<IndexType, AMRepeats>
gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[token_pos + m0];
@@ -1867,11 +2017,11 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
gather_offsets(m0) = token_offset * problem.K;
gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
});
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K);
const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
const index_t expert_scale_stride =
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N, ScaleBlockN) *
__builtin_amdgcn_readfirstlane(math::integer_divide_ceil(problem.N , ScaleBlockN) * (IsInputGemm ? 2 : 1) *
math::integer_divide_ceil(problem.K, ScaleBlockK));
// N0, K0, Blocksize*KPack
const index_t n_block_data_idx_on_grid =
@@ -1918,7 +2068,7 @@ struct GridwiseMoeGemmBlockScale
1,
AThreadTransferSrcResetCoordinateAfterRun,
true,
index_t,
IndexType,
1,
BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
make_multi_index(0, 0, 0),
@@ -1967,6 +2117,7 @@ struct GridwiseMoeGemmBlockScale
static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
decltype(c_thread_buf) c_thread_buf_up;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
@@ -2009,8 +2160,8 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scale_gather_offsets(m0) =
token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK);
scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
math::integer_divide_ceil(problem.K, ScaleBlockK);
});
// printf("blkid: %d, tid:%d, a_thread_offset: %d, scale_gather_offsets: %d\n", block_m_id,
@@ -2050,33 +2201,105 @@ struct GridwiseMoeGemmBlockScale
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
if constexpr(IsInputGemm)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_grid_up + expert_id * static_cast<long_index_t>(expert_stride) / BPackedSize,
b_grid_desc_bpreshuffled.GetElementSpaceSize());
auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
BDataType,
BDataType,
decltype(b_grid_desc_bpreshuffled),
decltype(b_block_desc_bk0_n_bk1),
Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
Sequence<1, 2, 0, 3>,
3,
BBlockTransferSrcScalarPerVector,
BThreadTransferSrcResetCoordinateAfterRun,
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
KPack / KGroup * (get_thread_local_1d_id() % warpSize)));
const BScaleType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
b_scale_grid_desc_bn_ak.GetElementSpaceSize());
auto b_scale_thread_copy_up =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<ScaleSliceSizeN, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_scale_thread_desc,
c_thread_buf,
a_scale_grid_desc_am_ak,
a_scale_thread_desc,
a_scale_thread_copy,
a_scale_grid_buf,
a_scale_thread_slice_copy_step,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_grid_buf,
b_scale_thread_slice_copy_step,
num_k_block_main_loop);
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_blockwise_copy_up,
b_grid_buf,
b_grid_buf_up,
b_block_bufs,
b_block_slice_copy_step,
c_scale_thread_desc,
c_thread_buf,
c_thread_buf_up,
a_scale_grid_desc_am_ak,
a_scale_thread_desc,
a_scale_thread_copy,
a_scale_grid_buf,
a_scale_thread_slice_copy_step,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_thread_copy_up,
b_scale_grid_buf,
b_scale_grid_buf_up,
b_scale_thread_slice_copy_step,
num_k_block_main_loop);
}
else
{
blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
a_grid_desc_ak0_m_ak1,
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bpreshuffled,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_bufs,
b_block_slice_copy_step,
c_scale_thread_desc,
c_thread_buf,
a_scale_grid_desc_am_ak,
a_scale_thread_desc,
a_scale_thread_copy,
a_scale_grid_buf,
a_scale_thread_slice_copy_step,
b_scale_grid_desc_bn_ak,
b_scale_thread_desc,
b_scale_thread_copy,
b_scale_grid_buf,
b_scale_thread_slice_copy_step,
num_k_block_main_loop);
}
// shuffle C and write out
{
@@ -2106,6 +2329,71 @@ struct GridwiseMoeGemmBlockScale
constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
if constexpr(IsInputGemm) // gu fusion, elementwise
{
static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
static_assert(N4 == 4);
const index_t n1 = get_warp_local_1d_id() / MWave;
const index_t n3 = threadIdx.x % get_warp_size() / NPerXdl;
vector_type<float, 4> topk_weights;
static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
static_for<0, N2, 1>{}([&](auto n2) { // num_groups_per_blk
const index_t n_pos = block_n_id * NPerBlock + n0 * N1 * N2 * N3 * N4 +
n1 * N2 * N3 * N4 + n2 * N3 * N4 + n3 * N4;
if constexpr(MulRoutedWeight)
{
topk_weights = *c_style_pointer_cast<const vector_type<float, N4>*>(
p_ds_grid[I0] + n_pos);
}
// if((blockIdx.x == 0) && (blockIdx.y == 0)){printf("m0:%d, n_pos:%d\n", static_cast<int>(m0), n_pos);}
static_for<0, N4, 1>{}([&](auto n4) { // inst_group_size
constexpr index_t c_offset =
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[n4];
up = up * topk_weights.AsType<float>()[n4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Silu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
else if(ActivationOperation == Activation::gelu_and_mul)
{
float gate = c_thread_buf[cidx];
float up = c_thread_buf_up[cidx];
if constexpr(MulRoutedWeight)
{
gate = gate * topk_weights.AsType<float>()[n4];
up = up * topk_weights.AsType<float>()[n4];
}
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
{
gate *= 16;
up *= 16;
}
tensor_operation::element_wise::Gelu{}(gate, gate);
c_thread_buf(cidx) = gate * up;
}
});
});
});
});
}
constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
@@ -2270,8 +2558,8 @@ struct GridwiseMoeGemmBlockScale
Sequence<true>,
uniform_sequence_gen_t<NumDTensor,
false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
index_t,
Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
IndexType,
1, // ScatterDim
true, // OutputScatter: false, only use scatter weights
scatter_weight_idx // ScatterWeightIdx: ascale
@@ -2315,7 +2603,7 @@ struct GridwiseMoeGemmBlockScale
CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats>
StaticallyIndexedArray<IndexType, EMRepeats>
scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk

View File

@@ -288,12 +288,14 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
oob_val = oob_val & is_src_valid;
if(i.value == ScatterWeightIdx)
{
auto data_types = SrcDatas{};
using DataType = remove_cvref_t<decltype(data_types[i])>;
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1,
"scatter weight dim, should only one vec");
constexpr auto iScatter =
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
src_vectors(i).template AsType<float>()(j) =
src_vectors(i).template AsType<DataType>()(j) =
scatter_weights(Number<iScatter>{});
});
}
@@ -547,8 +549,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(

View File

@@ -112,8 +112,9 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
#endif
return a_lds_block_desc;
#endif
}
template <typename Problem>

View File

@@ -0,0 +1,280 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <unordered_map>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename D2DataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
index_t ActivationType_ = 0,
bool MulRoutedWeight = true,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA>
struct ReferenceMoeGemm1BlockScale : public device::BaseOperator
{
// Argument
static constexpr auto ActivationType = ActivationType_;
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
const Tensor<D2DataType>& d2,
Tensor<CDataType>& c_t_k_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: sorted_token_ids_{sorted_token_ids},
expert_ids_{expert_ids},
max_token_id_{max_token_id},
sorted_tile_size_{sorted_tile_size},
a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k},
d2_{d2},
c_t_k_n_{c_t_k_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ck::index_t>& sorted_token_ids_;
const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& max_token_id_;
index_t sorted_tile_size_;
const Tensor<ADataType>& a_t_k_;
const Tensor<BDataType>& b_e_n_k_;
const Tensor<D2DataType>& d2_;
Tensor<CDataType>& c_t_k_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceMoeGemm1BlockScale::Argument;
float Run(const Argument& arg)
{
static_assert(ActivationType < 2, "Not supported activation type");
const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2];
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_.mDesc.GetLengths()[1];
AccDataType v_acc_up{0};
ComputeTypeB v_b_up{0};
AccDataType v_acc{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
const int t = arg.sorted_token_ids_(m) & 0xffffff;
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
D2DataType v_topk_w = arg.d2_(m, 0); // expert
if(t < token_cnt)
{
for(int k = 0; k < K; ++k)
{
if constexpr(is_same_v<ADataType, pk_i4_t>)
{
uint8_t i4x2 = arg.a_t_k_(t, k).data;
uint8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_a = i4_to_f32_gfx9(i4);
#else
v_a = i4 - 8;
#endif
}
else
{
arg.a_element_op_(v_a, arg.a_t_k_(t, k));
}
// same for B matrix
if constexpr(is_same_v<BDataType, pk_i4_t>)
{
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
uint8_t i4x2_up = arg.b_e_n_k_(e, k, n + full_n).data;
uint8_t i4 = 0;
uint8_t i4_up = 0;
if(k % 2 == 1)
{
i4 = (i4x2 >> 0) & 0xf;
i4_up = (i4x2_up >> 0) & 0xf;
}
else
{
i4 = (i4x2 >> 4) & 0xf;
i4_up = (i4x2_up >> 4) & 0xf;
}
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_b = i4_to_f32_gfx9(i4);
v_b_up = i4_to_f32_gfx9(i4_up);
#else
v_b = i4 - 8;
v_b_up = i4_up - 8;
#endif
}
else
{
arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n));
arg.b_element_op_(v_b_up, arg.b_e_n_k_(e, k, n + full_n));
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
v_acc_up += ck::type_convert<AccDataType>(v_a) *
ck::type_convert<AccDataType>(v_b_up);
}
CDataType v_c{0};
CDataType v_c_up{0};
if constexpr(MulRoutedWeight)
{
v_acc *= v_topk_w;
v_acc_up *= v_topk_w;
}
arg.c_element_op_(v_c, v_acc);
arg.c_element_op_(v_c_up, v_acc_up);
if constexpr(ActivationType == 1)
{
if constexpr(is_same_v<BDataType, pk_i4_t>)
{
v_c_up *= 16;
v_c *= 16;
}
tensor_operation::element_wise::Silu{}(v_c, v_c);
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
}
else if constexpr(ActivationType == 0)
{
if constexpr(is_same_v<BDataType, pk_i4_t>)
{
v_c_up *= 16;
v_c *= 16;
}
tensor_operation::element_wise::Gelu{}(v_c, v_c);
arg.c_t_k_n_(t, topk_id, n) = v_c * v_c_up;
}
}
};
const ck::index_t max_token_id = arg.max_token_id_(0);
make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
const Tensor<D2DataType>& d2,
Tensor<CDataType>& c_t_k_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{sorted_token_ids,
expert_ids,
max_token_id,
sorted_tile_size,
a_t_k,
b_e_n_k,
d2,
c_t_k_n,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceMoeGemm1BlaockScale"
<< std::endl;
// clang-format on
return str.str();
}
static float i4_to_f32_gfx9(uint8_t i4)
{
static std::unordered_map<uint8_t, float> u = {{0b1000, -0.5000f},
{0b1001, -0.4375f},
{0b1010, -0.3750f},
{0b1011, -0.3125f},
{0b1100, -0.2500f},
{0b1101, -0.1875f},
{0b1110, -0.1250f},
{0b1111, -0.0625f},
{0b0, +0.0000f},
{0b1, +0.0625f},
{0b10, +0.1250f},
{0b11, +0.1875f},
{0b100, +0.2500f},
{0b101, +0.3125f},
{0b110, +0.3750f},
{0b111, +0.4375f}};
return u[i4];
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck