mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Merge branch 'ck_moe_bs_splitk' into ck_moe_bs_splitk_pr
This commit is contained in:
@@ -18,6 +18,7 @@ add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
|
||||
add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp)
|
||||
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale_splitk moe_gemm1_xdl_fp8_blockscale_splitk.cpp)
|
||||
|
||||
list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic)
|
||||
set(target 0)
|
||||
|
||||
@@ -0,0 +1,537 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
using ::ck::DeviceMem;
|
||||
using ::ck::HostTensorDescriptor;
|
||||
using ::ck::Tensor;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F8 = ck::f8_t;
|
||||
using F32 = float;
|
||||
using I64 = int64_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
using Bypass = ck::tensor_layout::BypassLayoutVerification;
|
||||
|
||||
using A0DataType = F8;
|
||||
using A1DataType = F32;
|
||||
using B0DataType = F8;
|
||||
using B1DataType = F32;
|
||||
using EDataType = F32;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = EDataType;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D2Layout>;
|
||||
|
||||
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D2>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
|
||||
{
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16 / sizeof(B0DataType);
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
|
||||
int K0 = K / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
int tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * static_cast<I64>(K) + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
static constexpr ck::index_t Scale_Block_M = 1;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
static constexpr ck::index_t Scale_Block_K = 128;
|
||||
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t IsInputGemm= true; //splitk gemm1 goes to gemm2 pipeline.
|
||||
static constexpr ck::index_t IsSplitK = true; //splitk gemm1
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr bool MulRoutedWeight = false; //splitk gemm1 does not do routedWeight.
|
||||
|
||||
#if 1
|
||||
static constexpr ck::index_t MPerBlock = 32;
|
||||
static constexpr ck::index_t NPerBlock = 128;
|
||||
static constexpr ck::index_t MNPerXDL = 16;
|
||||
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
|
||||
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
|
||||
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
|
||||
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
|
||||
static constexpr ck::index_t BLOCKSIZE = 256;
|
||||
|
||||
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
|
||||
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
|
||||
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
|
||||
static constexpr ck::index_t D0Vec = 1;
|
||||
static constexpr ck::index_t D1Vec = 1;
|
||||
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
|
||||
// clang-format off
|
||||
< Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
//threadnum, mblock, nblock, kblock
|
||||
BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
// ak1, bk1
|
||||
AK1, BK1,
|
||||
// mn_perxdl
|
||||
MNPerXDL, MNPerXDL,
|
||||
// mn_xdlperwave
|
||||
MXDLPerWave, NXDLPerWave,
|
||||
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
|
||||
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#else
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
|
||||
Row, Col, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
|
||||
MPerBlock, 128, 128,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
|
||||
#endif
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
#if 1
|
||||
// GEMM shape
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 6144;
|
||||
// ck::index_t N = 128;
|
||||
// ck::index_t K = 512;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t topk = 2;
|
||||
// ck::index_t sorted_tile_num = 515;
|
||||
// ck::index_t valid_tile_num = 512;
|
||||
// ck::index_t tokens = 208;
|
||||
// ck::index_t sorted_tile_num = 15;
|
||||
// ck::index_t valid_tile_num = 13;
|
||||
// ck::index_t sorted_tile_num = 259;
|
||||
// ck::index_t valid_tile_num = 256;
|
||||
// ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 2;
|
||||
ck::index_t valid_tile_num = 2;
|
||||
ck::index_t tokens = 32;
|
||||
#else
|
||||
// deepseek
|
||||
ck::index_t N = 2048;
|
||||
ck::index_t K = 7168;
|
||||
ck::index_t experts = 256;
|
||||
ck::index_t topk = 8;
|
||||
ck::index_t tokens = 4096;
|
||||
ck::index_t sorted_tile_num = 261;
|
||||
ck::index_t valid_tile_num = 256;
|
||||
#endif
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else if(argc == 9)
|
||||
{
|
||||
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
sorted_tile_num = std::stoi(argv[7]);
|
||||
valid_tile_num = std::stoi(argv[8]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N * 2;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
|
||||
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
|
||||
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
|
||||
|
||||
ck::index_t KBatch = 6;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
|
||||
max_token_id.mData = {valid_size};
|
||||
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile && tokenid < tokens * topk)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<A1DataType> a1_t_k(HostTensorDescriptor(
|
||||
{tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{}));
|
||||
Tensor<B0DataType> b0_e_n_k(
|
||||
HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{}));
|
||||
Tensor<B1DataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts,
|
||||
(K + Scale_Block_K - 1) / Scale_Block_K,
|
||||
(N + Scale_Block_N - 1) / Scale_Block_N * 2},
|
||||
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN},
|
||||
Col{}));
|
||||
Tensor<B0DataType> b0_preshuffled(
|
||||
HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{}));
|
||||
Tensor<EDataType> e_t_n_host_result(
|
||||
HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{}));
|
||||
Tensor<EDataType> e_t_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{}));
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
|
||||
sorted_token_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
a1_device_buf.ToDevice(a1_t_k.mData.data());
|
||||
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
int NPerXdl = device_op.GetPreShuffleParameters();
|
||||
|
||||
preShuffleBuffer(
|
||||
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument =
|
||||
device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
|
||||
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
|
||||
sizeof(B0DataType) * K * N * 2 * experts +
|
||||
sizeof(EDataType) * valid_tile_num * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s.\n"
|
||||
<< device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> a_t_k({tokens, K});
|
||||
Tensor<float> b_e_n_k({experts, K, N * 2});
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
Tensor<float> c_t_k_n({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{});
|
||||
|
||||
// handle scale before ref.
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
a_t_k(t, k) = ck::type_convert<float>(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K);
|
||||
}
|
||||
}
|
||||
|
||||
for(int e = 0; e < experts; ++e)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
for(int n = 0; n < N * 2; ++n)
|
||||
{
|
||||
b_e_n_k(e, k, n) = ck::type_convert<float>(b0_e_n_k(e, k, n)) *
|
||||
b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N);
|
||||
}
|
||||
}
|
||||
}
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeGemm1BlockScaleSplitK<float,
|
||||
float,
|
||||
float,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a_t_k,
|
||||
b_e_n_k,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < valid_size; ++m)
|
||||
{
|
||||
|
||||
const int fuse_t = sorted_token_ids.mData[m];
|
||||
const int t = fuse_t & 0xffffff;
|
||||
const int topk_id = (fuse_t & 0xff000000) >> 24;
|
||||
|
||||
if(t >= tokens)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, topk_id, n) =
|
||||
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -74,6 +74,7 @@ template <typename ALayout,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool IsSplitK = false,
|
||||
bool MulRoutedWeight = false,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
@@ -156,6 +157,7 @@ struct DeviceMoeGemmBlockScale
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
IsSplitK,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
@@ -201,12 +203,11 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
@@ -249,11 +250,11 @@ struct DeviceMoeGemmBlockScale
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
// if(arg_.KBatch > 1)
|
||||
// hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
// 0,
|
||||
// arg_.M * arg_.N * sizeof(CDataType) * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
// stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
@@ -267,11 +268,11 @@ struct DeviceMoeGemmBlockScale
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
// if(arg.KBatch > 1)
|
||||
// hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
// 0,
|
||||
// arg.M * arg.N * sizeof(CDataType) * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
// stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
@@ -290,7 +291,7 @@ struct DeviceMoeGemmBlockScale
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
(IsInputGemm && !IsSplitK) ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
@@ -416,8 +417,8 @@ struct DeviceMoeGemmBlockScale
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// only impl kbatch 1 now
|
||||
if(arg.KBatch > 1)
|
||||
// only impl kbatch 1 for fp32
|
||||
if(arg.KBatch > 1 && !std::is_same_v<CDataType, float>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
@@ -441,6 +442,11 @@ struct DeviceMoeGemmBlockScale
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if (arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
|
||||
{
|
||||
// Not support Kpadding with KBatch > 1
|
||||
return false;
|
||||
}
|
||||
|
||||
if(get_warp_size() == 64)
|
||||
{
|
||||
|
||||
@@ -51,7 +51,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
|
||||
|
||||
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
|
||||
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
|
||||
karg.p_sorted_token_ids,
|
||||
karg.p_sorted_expert_ids,
|
||||
@@ -60,8 +60,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
|
||||
p_shared,
|
||||
karg,
|
||||
karg.a_element_op,
|
||||
@@ -101,8 +101,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
|
||||
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
|
||||
karg.p_ds_grid,
|
||||
karg.p_c_grid,
|
||||
karg.p_a_scale_grid,
|
||||
karg.p_b_scale_grid,
|
||||
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
|
||||
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
|
||||
p_shared,
|
||||
p_shared1,
|
||||
karg,
|
||||
@@ -167,6 +167,7 @@ template <typename ALayout,
|
||||
index_t ActivationOperation = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool IsSplitK = false,
|
||||
bool MulRoutedWeight = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
@@ -249,13 +250,15 @@ struct GridwiseMoeGemmBlockScale
|
||||
return 1;
|
||||
}();
|
||||
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
|
||||
{
|
||||
const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
|
||||
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
|
||||
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
|
||||
const index_t gridy = NSwizzle ? 1 : mblock;
|
||||
return std::make_tuple(gridx, gridy, 1);
|
||||
const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch);
|
||||
|
||||
return std::make_tuple(gridx, gridy, gridz);
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
@@ -284,27 +287,31 @@ struct GridwiseMoeGemmBlockScale
|
||||
|
||||
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * KPerBlock;
|
||||
// auto K_t = K_Batch * KPerBlock;
|
||||
// return (K + K_t - 1) / K_t * KPerBlock;
|
||||
return K_Batch == 1 ? K : K_Batch * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
auto K_t = K_Batch * KReadVec;
|
||||
return (K + K_t - 1) / K_t * KReadVec;
|
||||
// auto K_t = K_Batch * KReadVec;
|
||||
// return (K + K_t - 1) / K_t * KReadVec;
|
||||
return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec : K_Batch * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ __device__ static auto CalculateMBlock(index_t M)
|
||||
@@ -409,7 +416,6 @@ struct GridwiseMoeGemmBlockScale
|
||||
make_pass_through_transform(M)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return a_grid_desc_ak0_m_ak1;
|
||||
}
|
||||
}
|
||||
@@ -742,34 +748,41 @@ struct GridwiseMoeGemmBlockScale
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * karg.KRead / APackedSize;
|
||||
ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
|
||||
{
|
||||
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
|
||||
ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK);
|
||||
}
|
||||
|
||||
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
|
||||
bscale_k_split_offset = math::integer_divide_floor(b_k_split_offset, ScaleBlockK);
|
||||
}
|
||||
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
|
||||
{
|
||||
// KPack * NLane * KLane * K0 * N0
|
||||
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
|
||||
bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK + k_id * NLane / ScaleBlockN;
|
||||
}
|
||||
|
||||
if(k_id < karg.KBatch - 1)
|
||||
{
|
||||
karg.K = karg.KRead;
|
||||
}
|
||||
else
|
||||
{
|
||||
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
|
||||
}
|
||||
|
||||
|
||||
// if(k_id < karg.KBatch - 1)
|
||||
// {
|
||||
// karg.K = karg.KRead;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
|
||||
// }
|
||||
}
|
||||
|
||||
index_t a_k_split_offset;
|
||||
index_t b_k_split_offset;
|
||||
index_t ascale_k_split_offset;
|
||||
index_t bscale_k_split_offset;
|
||||
};
|
||||
|
||||
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
|
||||
@@ -939,7 +952,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
KPack,
|
||||
IsInputGemm>())>;
|
||||
IsInputGemm && !IsSplitK>())>;
|
||||
|
||||
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
|
||||
{
|
||||
@@ -1190,7 +1203,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
ignore = b_element_op;
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
|
||||
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
|
||||
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
|
||||
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
|
||||
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
|
||||
@@ -1204,8 +1217,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
|
||||
problem.MPadded,
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
@@ -1215,7 +1228,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
|
||||
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
|
||||
make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm && IsSplitK ? 2 : 1), ScaleBlockN),
|
||||
math::integer_divide_ceil(problem.K, ScaleBlockK)),
|
||||
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
|
||||
|
||||
@@ -1371,9 +1384,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
decltype(c_thread_buf) c_thread_buf_up;
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
|
||||
problem.KBatch == 1 ?(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock : problem.KBatch);
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
|
||||
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
|
||||
@@ -1447,7 +1459,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
|
||||
|
||||
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
|
||||
if constexpr(IsInputGemm)
|
||||
if constexpr(IsInputGemm && !IsSplitK)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -1606,7 +1618,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, n2 * N4 + n4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion, elementwise
|
||||
if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
@@ -1744,7 +1756,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
using EDataType = CDataType;
|
||||
|
||||
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
|
||||
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
|
||||
problem.M, problem.MPadded, problem.N * (IsInputGemm && IsSplitK ? 2 : 1), problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1), problem.StrideDs);
|
||||
|
||||
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
|
||||
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
@@ -1875,7 +1887,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_offsets(m0) = token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
@@ -1953,8 +1965,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
|
||||
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
|
||||
problem.MPadded,
|
||||
problem.N,
|
||||
problem.NPadded,
|
||||
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
|
||||
problem.StrideC);
|
||||
|
||||
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
|
||||
@@ -2125,8 +2137,8 @@ struct GridwiseMoeGemmBlockScale
|
||||
decltype(c_thread_buf) c_thread_buf_up;
|
||||
|
||||
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
|
||||
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock);
|
||||
problem.KBatch == 1 ?(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
|
||||
KPerBlock : problem.KBatch);
|
||||
|
||||
// scale
|
||||
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
|
||||
@@ -2202,7 +2214,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
|
||||
|
||||
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
|
||||
if constexpr(IsInputGemm)
|
||||
if constexpr(IsInputGemm && !IsSplitK)
|
||||
{
|
||||
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
|
||||
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
|
||||
@@ -2352,7 +2364,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
|
||||
make_tuple(m0, n0, n2 * N4 + n4));
|
||||
constexpr auto cidx = Number<c_offset>{};
|
||||
if constexpr(IsInputGemm) // gu fusion, elementwise
|
||||
if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
|
||||
{
|
||||
if constexpr(ActivationOperation == Activation::silu_and_mul)
|
||||
{
|
||||
@@ -2619,7 +2631,7 @@ struct GridwiseMoeGemmBlockScale
|
||||
{
|
||||
token_offset = token_offset * problem.TopK + (fused_token >> 24);
|
||||
}
|
||||
scatter_offsets(m0) = token_offset * problem.N;
|
||||
scatter_offsets(m0) = token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -0,0 +1,232 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace host {
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename CDataType,
|
||||
typename AccDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename ComputeTypeA = AccDataType,
|
||||
typename ComputeTypeB = AccDataType>
|
||||
struct ReferenceMoeGemm1BlockScaleSplitK : public device::BaseOperator
|
||||
{
|
||||
// Argument
|
||||
struct Argument : public device::BaseArgument
|
||||
{
|
||||
Argument(const Tensor<ck::index_t>& sorted_token_ids,
|
||||
const Tensor<ck::index_t>& expert_ids,
|
||||
const Tensor<ck::index_t>& max_token_id,
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
Tensor<CDataType>& c_t_k_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
: sorted_token_ids_{sorted_token_ids},
|
||||
expert_ids_{expert_ids},
|
||||
max_token_id_{max_token_id},
|
||||
sorted_tile_size_{sorted_tile_size},
|
||||
a_t_k_{a_t_k},
|
||||
b_e_n_k_{b_e_n_k},
|
||||
c_t_k_n_{c_t_k_n},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
c_element_op_{c_element_op}
|
||||
{
|
||||
}
|
||||
|
||||
const Tensor<ck::index_t>& sorted_token_ids_;
|
||||
const Tensor<ck::index_t>& expert_ids_;
|
||||
const Tensor<ck::index_t>& max_token_id_;
|
||||
index_t sorted_tile_size_;
|
||||
const Tensor<ADataType>& a_t_k_;
|
||||
const Tensor<BDataType>& b_e_n_k_;
|
||||
Tensor<CDataType>& c_t_k_n_;
|
||||
|
||||
AElementwiseOperation a_element_op_;
|
||||
BElementwiseOperation b_element_op_;
|
||||
CElementwiseOperation c_element_op_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public device::BaseInvoker
|
||||
{
|
||||
using Argument = ReferenceMoeGemm1BlockScaleSplitK::Argument;
|
||||
|
||||
float Run(const Argument& arg)
|
||||
{
|
||||
const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2];
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = arg.a_t_k_.mDesc.GetLengths()[1];
|
||||
AccDataType v_acc{0};
|
||||
|
||||
ComputeTypeA v_a{0};
|
||||
ComputeTypeB v_b{0};
|
||||
|
||||
const int t = arg.sorted_token_ids_(m) & 0xffffff;
|
||||
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
|
||||
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
|
||||
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
|
||||
if(t < token_cnt)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
if constexpr(is_same_v<ADataType, pk_i4_t>)
|
||||
{
|
||||
uint8_t i4x2 = arg.a_t_k_(t, k).data;
|
||||
uint8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
else
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_a = i4_to_f32_gfx9(i4);
|
||||
#else
|
||||
v_a = i4 - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.a_element_op_(v_a, arg.a_t_k_(t, k));
|
||||
}
|
||||
// same for B matrix
|
||||
if constexpr(is_same_v<BDataType, pk_i4_t>)
|
||||
{
|
||||
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
|
||||
uint8_t i4 = 0;
|
||||
if(k % 2 == 1)
|
||||
{
|
||||
i4 = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
else
|
||||
{
|
||||
i4 = (i4x2 >> 4) & 0xf;
|
||||
}
|
||||
#if CK_USE_PK4_LAYOUT_SHUFFLE
|
||||
v_b = i4_to_f32_gfx9(i4);
|
||||
#else
|
||||
v_b = i4 - 8;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n));
|
||||
}
|
||||
|
||||
v_acc +=
|
||||
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
|
||||
}
|
||||
CDataType v_c{0};
|
||||
|
||||
arg.c_element_op_(v_c, v_acc);
|
||||
arg.c_t_k_n_(t, topk_id, n) = v_c;
|
||||
}
|
||||
};
|
||||
|
||||
const ck::index_t max_token_id = arg.max_token_id_(0);
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)(
|
||||
std::thread::hardware_concurrency());
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
float Run(const device::BaseArgument* p_arg,
|
||||
const StreamConfig& /* stream_config */ = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
|
||||
|
||||
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
|
||||
const Tensor<ck::index_t>& expert_ids,
|
||||
const Tensor<ck::index_t>& max_token_id,
|
||||
const index_t sorted_tile_size,
|
||||
const Tensor<ADataType>& a_t_k,
|
||||
const Tensor<BDataType>& b_e_n_k,
|
||||
Tensor<CDataType>& c_t_k_n,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
sorted_tile_size,
|
||||
a_t_k,
|
||||
b_e_n_k,
|
||||
c_t_k_n,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "ReferenceMoeGemm1BlaockScaleSplitK"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
|
||||
static float i4_to_f32_gfx9(uint8_t i4)
|
||||
{
|
||||
static std::unordered_map<uint8_t, float> u = {{0b1000, -0.5000f},
|
||||
{0b1001, -0.4375f},
|
||||
{0b1010, -0.3750f},
|
||||
{0b1011, -0.3125f},
|
||||
{0b1100, -0.2500f},
|
||||
{0b1101, -0.1875f},
|
||||
{0b1110, -0.1250f},
|
||||
{0b1111, -0.0625f},
|
||||
{0b0, +0.0000f},
|
||||
{0b1, +0.0625f},
|
||||
{0b10, +0.1250f},
|
||||
{0b11, +0.1875f},
|
||||
{0b100, +0.2500f},
|
||||
{0b101, +0.3125f},
|
||||
{0b110, +0.3750f},
|
||||
{0b111, +0.4375f}};
|
||||
|
||||
return u[i4];
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace host
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user