Dev/a8w4 and a8w8splitk (#3447)

* Ck moe bs splitk pr (#3440)

* splitk kick-off. Compilation fail

* splitk hack pass

* fix scale offset calc.

* clang-format for a8w8_moe_blk_gemm1 splitk change

* fix testcase error

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>

* Zan/moe a8w4 (#3441)

* update

* update

* update ck moe a8w4

* update

* update

* update

* compile pass

* update

* update

* python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready

* support new a8w4 kernel

* update

* update ck_tile

* reformat

* update

* update

* fix conflict

* fix build

* update ck_tile moe

* fix clang format

* fix the problem

* fix accuracy issue

* fix

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>
Co-authored-by: Zzz9990 <zanzhang@amd.com>
Co-authored-by: felix <felix.li@amd.com>

[ROCm/composable_kernel commit: c0ee71d735]
Author: yadaish
Date: 2025-12-19 09:26:52 +08:00
Committed by: GitHub
Parent: 4693c2c2f1
Commit: e76ee195df
13 changed files with 2911 additions and 139 deletions

View File

@@ -18,6 +18,7 @@ add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp)
add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp)
add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp)
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp)
add_example_executable(example_moe_gemm1_xdl_fp8_blockscale_splitk moe_gemm1_xdl_fp8_blockscale_splitk.cpp)
list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic)
set(target 0)

View File

@@ -171,7 +171,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
@@ -185,7 +185,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>;
#endif
// clang-format on
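The two hunks above only add an extra "false" to the template argument lists. As the device header diff later in this commit shows, DeviceMoeGemmBlockScale gains a new defaulted bool, IsSplitK, inserted ahead of MulRoutedWeight, so every positional instantiation now has to spell it out. A minimal sketch with a hypothetical stub type (not the CK declaration):

// Hypothetical stub only; illustrates why existing instantiations gain one more "false".
template <bool IsInputGemm = true, bool IsSplitK = false, bool MulRoutedWeight = false>
struct DeviceMoeGemmStub
{
};

// before this commit: DeviceMoeGemmStub<true, /*MulRoutedWeight*/ false>
// after this commit:  DeviceMoeGemmStub<true, /*IsSplitK*/ false, /*MulRoutedWeight*/ false>
using Gemm1Instance = DeviceMoeGemmStub<true, false, false>;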

View File

@@ -0,0 +1,539 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
using ::ck::DeviceMem;
using ::ck::HostTensorDescriptor;
using ::ck::Tensor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using F32 = float;
using I64 = int64_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Bypass = ck::tensor_layout::BypassLayoutVerification;
using A0DataType = F8;
using A1DataType = F32;
using B0DataType = F8;
using B1DataType = F32;
using EDataType = F32;
using AccDataType = F32;
using CShuffleDataType = EDataType;
using D2DataType = F32;
using DsDataType = ck::Tuple<D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D2Layout>;
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D2>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void
operator()<EDataType, EDataType, float>(EDataType& e, const EDataType& c, const float& d2) const
{
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl)
{
int KPack = 16 / sizeof(B0DataType);
int NLane = NXdl;
int KLane = 64 / NLane;
int K0 = K / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * static_cast<I64>(K) + k];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr ck::index_t Scale_Block_M = 1;
static constexpr ck::index_t Scale_Block_N = 128;
static constexpr ck::index_t Scale_Block_K = 128;
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t IsInputGemm = true; // splitk gemm1 goes to gemm2 pipeline.
static constexpr ck::index_t IsSplitK = true; // splitk gemm1
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight.
#if 1
static constexpr ck::index_t MPerBlock = 32;
static constexpr ck::index_t NPerBlock = 128;
static constexpr ck::index_t MNPerXDL = 16;
static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1);
static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4);
static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave;
static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave;
static constexpr ck::index_t BLOCKSIZE = 256;
static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType);
static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType);
static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType);
static constexpr ck::index_t EVec = 16 / sizeof(EDataType);
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale
// clang-format off
< Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
//threadnum, mblock, nblock, kblock
BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, NPerBlock, KPerBlock,
// ak1, bk1
AK1, BK1,
// mn_perxdl
MNPerXDL, MNPerXDL,
// mn_xdlperwave
MXDLPerWave, NXDLPerWave,
// a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0,
// CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S<EVec, D0Vec, D1Vec, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
Row, Col, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
256, Scale_Block_M, Scale_Block_N, Scale_Block_K,
MPerBlock, 128, 128,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>;
#endif
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
#if 1
// GEMM shape
ck::index_t N = 4096;
ck::index_t K = 6144;
// ck::index_t N = 128;
// ck::index_t K = 512;
ck::index_t experts = 8;
ck::index_t topk = 2;
// ck::index_t sorted_tile_num = 515;
// ck::index_t valid_tile_num = 512;
// ck::index_t tokens = 208;
// ck::index_t sorted_tile_num = 15;
// ck::index_t valid_tile_num = 13;
// ck::index_t sorted_tile_num = 259;
// ck::index_t valid_tile_num = 256;
// ck::index_t tokens = 4096;
ck::index_t sorted_tile_num = 2;
ck::index_t valid_tile_num = 2;
ck::index_t tokens = 32;
#else
// deepseek
ck::index_t N = 2048;
ck::index_t K = 7168;
ck::index_t experts = 256;
ck::index_t topk = 8;
ck::index_t tokens = 4096;
ck::index_t sorted_tile_num = 261;
ck::index_t valid_tile_num = 256;
#endif
ck::index_t KBatch = 6;
if(argc == 1)
{
// use default case
}
else if(argc == 2)
{
KBatch = std::stoi(argv[1]);
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else if(argc == 9)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
sorted_tile_num = std::stoi(argv[7]);
valid_tile_num = std::stoi(argv[8]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N * 2;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0};
ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K;
ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1 + sorted_tile_num}));
max_token_id.mData = {valid_size};
// int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3};
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile && tokenid < tokens * topk)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<A1DataType> a1_t_k(HostTensorDescriptor(
{tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{}));
Tensor<B0DataType> b0_e_n_k(
HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{}));
Tensor<B1DataType> b1_e_n_k(
HostTensorDescriptor({experts,
(K + Scale_Block_K - 1) / Scale_Block_K,
(N + Scale_Block_N - 1) / Scale_Block_N * 2},
{(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN},
Col{}));
Tensor<B0DataType> b0_preshuffled(
HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{}));
Tensor<EDataType> e_t_n_host_result(
HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{}));
Tensor<EDataType> e_t_n_device_result(
HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
std::cout << "k_batch:" << KBatch << std::endl;
std::cout << "init_method:" << init_method << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-1.0, 1.0});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-1.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0.0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<A1DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<B1DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<A1DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<B1DataType>{0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) *
sorted_token_ids.mDesc.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize());
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
a1_device_buf.ToDevice(a1_t_k.mData.data());
b1_device_buf.ToDevice(b1_e_n_k.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
int NPerXdl = device_op.GetPreShuffleParameters();
preShuffleBuffer(
b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl);
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
a1_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K;
std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K +
sizeof(B0DataType) * K * N * 2 * experts +
sizeof(EDataType) * valid_tile_num * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s.\n"
<< device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> a_t_k({tokens, K});
Tensor<float> b_e_n_k({experts, K, N * 2});
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
Tensor<float> c_t_k_n({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{});
// handle scale before ref.
for(int t = 0; t < tokens; ++t)
{
for(int k = 0; k < K; ++k)
{
a_t_k(t, k) = ck::type_convert<float>(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K);
}
}
for(int e = 0; e < experts; ++e)
{
for(int k = 0; k < K; ++k)
{
for(int n = 0; n < N * 2; ++n)
{
b_e_n_k(e, k, n) = ck::type_convert<float>(b0_e_n_k(e, k, n)) *
b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N);
}
}
}
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeGemm1BlockScaleSplitK<float,
float,
float,
AccDataType,
PassThrough,
PassThrough,
PassThrough>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a_t_k,
b_e_n_k,
c_t_k_n,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < valid_size; ++m)
{
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
if(t >= tokens)
{
continue;
}
for(int n = 0; n < 2 * N; ++n)
{
e_t_n_host_result(t, topk_id, n) =
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
auto status =
ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
if(status == 0)
{
printf("Validation Pass.\n");
}
return status;
}
return 0;
}
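For reference, the sorted_token_ids values the example fills in and later unpacks keep the token index in the low 24 bits and the top-k slot in the high 8 bits. A small standalone round trip of that packing (plain C++, mirroring the expressions used in the example above):

// Sketch of the sorted_token_ids encoding: pack as in the fill loop,
// unpack as in the verification loop.
#include <cstdio>

int main()
{
    const int tokens = 32;
    const int topk   = 2;
    for(int tokenid = 0; tokenid < tokens * topk; tokenid += 31)
    {
        const unsigned fused = (tokenid % tokens) | ((tokenid / tokens) << 24);
        const int t       = fused & 0xffffff;     // token index (low 24 bits)
        const int topk_id = (fused >> 24) & 0xff; // top-k slot (high 8 bits)
        std::printf("tokenid=%d -> fused=0x%08x -> t=%d topk_id=%d\n", tokenid, fused, t, topk_id);
    }
    return 0;
}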

View File

@@ -165,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0,
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>;
#else
static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale<
@@ -180,7 +180,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>;
#endif
// clang-format on

View File

@@ -360,6 +360,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
});
__builtin_amdgcn_sched_barrier(0);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
@@ -550,6 +551,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
});
__builtin_amdgcn_sched_barrier(0);
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
@@ -677,6 +679,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1<
});
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
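Context for the __builtin_amdgcn_sched_barrier(0) calls added in this file: with a mask of 0 the AMDGPU backend is told that no instructions may be scheduled across the barrier, which keeps the surrounding scale loads and LDS prefills in program order. A hedged HIP sketch of the shape of such a region (illustrative names, not the pipeline code):

#include <hip/hip_runtime.h>

__device__ void ordered_region(float* lds_dst, const float* global_src)
{
    float v = global_src[threadIdx.x]; // global load that must complete first
#if defined(__HIP_DEVICE_COMPILE__)
    __builtin_amdgcn_sched_barrier(0); // mask 0: nothing may be scheduled across this point
#endif
    lds_dst[threadIdx.x] = v;          // the dependent store stays after the load
}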

View File

@@ -74,6 +74,7 @@ template <typename ALayout,
index_t ActivationOP = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool IsSplitK = false,
bool MulRoutedWeight = false,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
@@ -156,6 +157,7 @@ struct DeviceMoeGemmBlockScale
ActivationOP,
NSwizzle,
IsInputGemm,
IsSplitK,
MulRoutedWeight,
IndexType,
ComputeTypeA,
@@ -201,12 +203,12 @@ struct DeviceMoeGemmBlockScale
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
@@ -249,11 +251,12 @@ struct DeviceMoeGemmBlockScale
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
// if(arg_.KBatch > 1)
// hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
// 0,
// arg_.M * arg_.N * sizeof(CDataType)
// * (IsInputGemm && IsSplitK ? 2 : 1),
// stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
@@ -267,11 +270,12 @@ struct DeviceMoeGemmBlockScale
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
// if(arg.KBatch > 1)
// hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
// 0,
// arg.M * arg.N * sizeof(CDataType) *
// (IsInputGemm && IsSplitK ? 2 : 1),
// stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
@@ -289,8 +293,9 @@ struct DeviceMoeGemmBlockScale
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK)
? InMemoryDataOperationEnum::Set
: InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
@@ -416,8 +421,8 @@ struct DeviceMoeGemmBlockScale
static bool IsSupportedArgument(const Argument& arg)
{
// only impl kbatch 1 now
if(arg.KBatch > 1)
// only impl kbatch 1 for fp32
if(arg.KBatch > 1 && !std::is_same_v<CDataType, float>)
{
return false;
}
@@ -441,6 +446,11 @@ struct DeviceMoeGemmBlockScale
{
return false;
}
if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
{
// Not support Kpadding with KBatch > 1
return false;
}
if(get_warp_size() == 64)
{
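A host-side sketch (hypothetical helper, not the CK API) of the split-K constraints the updated IsSupportedArgument enforces: KBatch > 1 is only taken on the fp32-output AtomicAdd path, and K must split evenly so no K padding is needed.

#include <type_traits>

template <typename CDataType>
bool splitk_config_supported(int K, int KPerBlock, int KBatch)
{
    if(KBatch > 1 && !std::is_same_v<CDataType, float>)
        return false; // atomic accumulation across K slices is only wired up for float C
    if(KBatch > 1 && K % (KPerBlock * KBatch) != 0)
        return false; // K padding is not supported with KBatch > 1
    return true;
}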

View File

@@ -60,8 +60,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -101,8 +101,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset,
p_shared,
p_shared1,
karg,
@@ -167,6 +167,7 @@ template <typename ALayout,
index_t ActivationOperation = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool IsSplitK = false,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
@@ -249,13 +250,15 @@ struct GridwiseMoeGemmBlockScale
return 1;
}();
__host__ static auto CalculateGridSize(index_t M, index_t N)
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t K, index_t KBatch)
{
const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
const index_t gridx = NSwizzle ? nblock * mblock : nblock;
const index_t gridy = NSwizzle ? 1 : mblock;
return std::make_tuple(gridx, gridy, 1);
const index_t gridz = KBatch == 1 ? 1 : math::integer_divide_ceil(K, KPerBlock * KBatch);
return std::make_tuple(gridx, gridy, gridz);
}
__host__ __device__ static auto CalculateMPadded(index_t M)
@@ -284,27 +287,32 @@ struct GridwiseMoeGemmBlockScale
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
return K_Batch == 1 ? K / AK1Value : K_Batch * KPerBlock / AK1Value;
}
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
return K_Batch == 1 ? K / BK1Value : K_Batch * KPerBlock / BK1Value;
}
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
{
auto K_t = K_Batch * KPerBlock;
return (K + K_t - 1) / K_t * KPerBlock;
// auto K_t = K_Batch * KPerBlock;
// return (K + K_t - 1) / K_t * KPerBlock;
return K_Batch == 1 ? K : K_Batch * KPerBlock;
}
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
{
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
auto K_t = K_Batch * KReadVec;
return (K + K_t - 1) / K_t * KReadVec;
// auto K_t = K_Batch * KReadVec;
// return (K + K_t - 1) / K_t * KReadVec;
return K_Batch == 1 ? math::integer_divide_ceil(K, KReadVec) * KReadVec
: K_Batch * KPerBlock;
}
__host__ __device__ static auto CalculateMBlock(index_t M)
@@ -409,7 +417,6 @@ struct GridwiseMoeGemmBlockScale
make_pass_through_transform(M)),
make_tuple(Sequence<1>{}, Sequence<0>{}),
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
return a_grid_desc_ak0_m_ak1;
}
}
@@ -741,35 +748,41 @@ struct GridwiseMoeGemmBlockScale
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead / APackedSize;
a_k_split_offset = k_id * karg.KRead / APackedSize;
ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK);
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
bscale_k_split_offset = math::integer_divide_floor(b_k_split_offset, ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
// KPack * NLane * KLane * K0 * N0
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize;
bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK;
}
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
// if(k_id < karg.KBatch - 1)
// {
// karg.K = karg.KRead;
// }
// else
// {
// karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
// }
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t ascale_k_split_offset;
index_t bscale_k_split_offset;
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -912,8 +925,8 @@ struct GridwiseMoeGemmBlockScale
}
using BlockwiseGemmPipe =
remove_cvref_t<decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector<
BlkGemmPipelineVer,
remove_cvref_t<decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector <
BlkGemmPipelineVer,
BlkGemmPipeSched,
BlockSize,
ADataType,
@@ -939,7 +952,7 @@ struct GridwiseMoeGemmBlockScale
MXdlPerWave,
NXdlPerWave,
KPack,
IsInputGemm>())>;
IsInputGemm && !IsSplitK > ())>;
__device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -1189,9 +1202,9 @@ struct GridwiseMoeGemmBlockScale
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1));
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
problem.MPadded,
@@ -1204,8 +1217,8 @@ struct GridwiseMoeGemmBlockScale
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
problem.MPadded,
problem.N,
problem.NPadded,
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
@@ -1215,7 +1228,8 @@ struct GridwiseMoeGemmBlockScale
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1));
@@ -1371,9 +1385,10 @@ struct GridwiseMoeGemmBlockScale
decltype(c_thread_buf) c_thread_buf_up;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
problem.KBatch == 1
? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock
: problem.KBatch);
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN);
constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
@@ -1447,7 +1462,7 @@ struct GridwiseMoeGemmBlockScale
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
if constexpr(IsInputGemm)
if constexpr(IsInputGemm && !IsSplitK)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -1606,7 +1621,7 @@ struct GridwiseMoeGemmBlockScale
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion, elementwise
if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
@@ -1743,8 +1758,12 @@ struct GridwiseMoeGemmBlockScale
using EDataType = CDataType;
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto ds_grid_desc_m_n =
MakeDsGridDescriptor_M_N(problem.M,
problem.MPadded,
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
problem.StrideDs);
const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
@@ -1875,7 +1894,8 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_offsets(m0) =
token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
});
block_sync_lds();
@@ -1953,8 +1973,8 @@ struct GridwiseMoeGemmBlockScale
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
problem.MPadded,
problem.N,
problem.NPadded,
problem.N * (IsInputGemm && IsSplitK ? 2 : 1),
problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1),
problem.StrideC);
const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor(
@@ -2125,8 +2145,10 @@ struct GridwiseMoeGemmBlockScale
decltype(c_thread_buf) c_thread_buf_up;
const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
(a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock);
problem.KBatch == 1
? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
KPerBlock
: problem.KBatch);
// scale
constexpr index_t ScaleSliceSizeM = MXdlPerWave;
@@ -2202,7 +2224,7 @@ struct GridwiseMoeGemmBlockScale
constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock);
if constexpr(IsInputGemm)
if constexpr(IsInputGemm && !IsSplitK)
{
const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
@@ -2352,7 +2374,7 @@ struct GridwiseMoeGemmBlockScale
blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
make_tuple(m0, n0, n2 * N4 + n4));
constexpr auto cidx = Number<c_offset>{};
if constexpr(IsInputGemm) // gu fusion, elementwise
if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise
{
if constexpr(ActivationOperation == Activation::silu_and_mul)
{
@@ -2619,7 +2641,8 @@ struct GridwiseMoeGemmBlockScale
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_offsets(m0) =
token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1);
});
block_sync_lds();
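A standalone sketch (plain C++, mirroring but not calling the CK code above) of the split-K bookkeeping this file adds, as I read the new CalculateGridSize, CalculateKRead and SplitKBatchOffset: with KBatch > 1 each blockIdx.z slice reads KRead = KBatch * KPerBlock contiguous elements of K, runs KBatch main-loop iterations, and advances the block-scale pointers by KRead / ScaleBlockK entries; grid.z covers the remaining K chunks and the slices accumulate into C via AtomicAdd.

#include <cassert>
#include <cstdio>

int main()
{
    const int K = 6144, KPerBlock = 64, KBatch = 6, ScaleBlockK = 128; // example shape
    assert(K % (KPerBlock * KBatch) == 0); // required when KBatch > 1 (no K padding)

    const int KRead = (KBatch == 1) ? K : KBatch * KPerBlock; // per-slice K extent
    const int gridz = (KBatch == 1) ? 1 : (K + KPerBlock * KBatch - 1) / (KPerBlock * KBatch);

    for(int k_id = 0; k_id < gridz; ++k_id) // k_id plays the role of blockIdx.z
    {
        const int a_k_split_offset      = k_id * KRead;                   // row-major A, in elements
        const int ascale_k_split_offset = a_k_split_offset / ScaleBlockK; // blockscale entries
        std::printf("z=%d a_off=%d ascale_off=%d\n", k_id, a_k_split_offset, ascale_k_split_offset);
    }
    return 0;
}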

View File

@@ -218,6 +218,44 @@ struct tile_scatter_gather
pre_computed_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
{
auto partition_index = get_partition_index(tile_distribution);
auto use_lane_id_0 = partition_index;
use_lane_id_0[1] = 0;
const auto window_adaptor_thread_coord_tmp_warp = make_tensor_adaptor_coordinate(
tile_distribution.get_ps_ys_to_xs_adaptor(),
container_concat(use_lane_id_0, array<index_t, NDimY>{0}));
BottomTensorIndex bottom_tensor_thread_origin_idx_tmp_warp =
window_origin + window_adaptor_thread_coord_tmp_warp.get_bottom_index();
bottom_tensor_thread_origin_idx_tmp_warp(HsGatherDim) = 0;
const auto bottom_tensor_thread_coord_tmp_warp =
make_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
bottom_tensor_thread_origin_idx_tmp_warp);
// pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
// future load/store() calls (might allocate more registers)
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp_warp;
auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp_warp;
constexpr auto idx_diff_ys =
SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
idx_diff_ys);
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
pre_computed_warp_coords_(iCoord) =
make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
});
}
}
CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; }
@@ -602,6 +640,135 @@ struct tile_scatter_gather
});
}
// TODO: fix with swizzle
template <typename LdsTileWindow_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true,
bool static_move_ys = false,
typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
CK_TILE_DEVICE void async_load_with_offset(index_t offset,
LdsTileWindow_&& lds_tile,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {},
bool_constant<static_move_ys> = {}) const
{
using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
using LdsDataType = typename LdsTileWindow::DataType;
using Traits = load_store_traits;
using vector_t = typename Traits::vector_t;
using SFC_Ys = typename Traits::SFC_Ys;
// Precompute invariant values outside loops
const auto window_origin = lds_tile.get_window_origin();
const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0];
auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1];
static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
constexpr auto idx_ys_offset = [&]() {
constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
container_concat(array<index_t, NDimP>{0},
to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
return adapter_ys_offset.get_bottom_index();
}();
const auto lds_ys_offset = [&]() {
if constexpr(static_move_ys)
{
const auto coord_ys_offset =
make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
return coord_ys_offset.get_offset();
}
else
return 0;
}();
// Use precomputed window origin & tensor descriptor
auto lds_bottom_tensor_thread_idx =
window_origin + window_adaptor_warp_coord.get_bottom_index();
const auto lds_coord =
make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
// Calculate SMEM address using base pointer
CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
lds_coord.get_offset() / Traits::PackedSize +
lds_ys_offset / Traits::PackedSize;
const auto dram_ys_offset = [&]() {
if constexpr(static_move_ys)
{
const auto coord_ys_offset = make_tensor_coordinate(
this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
return coord_ys_offset.get_offset();
}
else
return 0;
}();
constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
const auto page_offset = page_idx_[idx_gather];
auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
{
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
mixed_bottom_thread_coord,
offset + dram_ys_offset,
bool_constant<oob_conditional_check>{});
}
else
{
this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
smem,
mixed_bottom_thread_coord,
offset + dram_ys_offset,
valids_[idx_gather],
bool_constant<oob_conditional_check>{});
}
// Move thread coordinate if not last access
if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
{
constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
constexpr auto forward_step_scatter = generate_tuple(
[&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
number<NDimY>{});
constexpr auto idx_diff_ps_ys = container_concat(
generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
forward_step_scatter);
if constexpr(!static_move_ys)
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_thread_coord,
bottom_tensor_thread_coord,
idx_diff_ps_ys);
if constexpr(!static_move_ys)
move_window_adaptor_and_bottom_tensor_thread_coordinate(
window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
}
});
});
}
template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
number<i_access_unsupport_> = {},
@@ -788,6 +955,15 @@ struct tile_scatter_gather
pre_computed_coords_(iCoord)(I1),
step_new);
});
if constexpr(BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global)
{
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
pre_computed_warp_coords_(iCoord)(I1),
step_new);
});
}
}
CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
@@ -892,6 +1068,11 @@ struct tile_scatter_gather
// per-thread coordinate for window adaptor
// per-thread coordinate for bottom tensor
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
std::conditional_t<BottomTensorView::buffer_view::get_address_space() ==
address_space_enum::global,
array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord>,
std::byte>
pre_computed_warp_coords_;
};
// TODO: use strategy
@@ -906,7 +1087,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view,
const WindowLengths_& window_lengths,
const multi_index<TensorView_::get_num_of_dimension()>& origin,
const StaticTileDistribution_& tile_distribution,
const StaticPageIndexArray_& page_idx,
const StaticPageIndexArray_& page_idx, // perbytes
number<HsGatherDim> = {},
number<NumCoord> = {})
{
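Sketch of the conditional-member idiom used for pre_computed_warp_coords_ above: the extra per-warp coordinate cache only takes real storage when the bottom tensor view lives in global memory; otherwise it degenerates to a single placeholder byte. Names below are illustrative, not ck_tile types.

#include <array>
#include <cstddef>
#include <type_traits>

template <bool IsGlobalMemory>
struct scatter_window_state
{
    // real cache on the global-memory path, a placeholder byte otherwise
    std::conditional_t<IsGlobalMemory, std::array<int, 4>, std::byte> warp_coords_;
};

static_assert(sizeof(scatter_window_state<true>) > sizeof(scatter_window_state<false>),
              "the cache is only paid for when gathering from global memory");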

View File

@@ -217,6 +217,7 @@ struct MoeFlatmmKernel
static constexpr auto I1 = number<1>();
static constexpr auto I2 = number<2>();
static constexpr auto I3 = number<3>();
static constexpr auto I4 = number<4>();
static_assert(DsLayout::size() == DsDataType::size(),
"The size of DsLayout and DsDataType should be the same");
@@ -241,12 +242,24 @@ struct MoeFlatmmKernel
IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock;
// MXF4_Pipeline only has the of scale B and granularityK is 32
static constexpr bool MXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static constexpr int MXFP4N_Pack = 2;
static constexpr int MXFP4K_Pack = 2;
static constexpr bool AQUANT_Pipeline = std::is_same_v<ADataType, bf8_t> ||
std::is_same_v<ADataType, fp8_t> ||
std::is_same_v<ADataType, pk_fp4_t>;
static constexpr bool BMXFP4_Pipeline = std::is_same_v<BDataType, pk_fp4_t>;
static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1;
static constexpr int K_Pack = MXFP4_Pipeline ? MXFP4K_Pack : 1;
static constexpr bool MXF8F6F4MFMA =
#ifdef __gfx950__
AQUANT_Pipeline && BMXFP4_Pipeline;
#else
false;
#endif
static constexpr int MXFP4M_Pack = 2;
static constexpr int MXFP4N_Pack = 2;
static constexpr int MXFP4K_Pack = 2;
static constexpr int M_Pack = AQUANT_Pipeline ? MXFP4M_Pack : 1;
static constexpr int N_Pack = BMXFP4_Pipeline ? MXFP4N_Pack : 1;
static constexpr int K_Pack = BMXFP4_Pipeline ? MXFP4K_Pack : 1;
static constexpr int WeightPackedSize = numeric_traits<BDataType>::PackedSize;
@@ -659,23 +672,95 @@ struct MoeFlatmmKernel
}
}();
auto scale_n = kargs.scale_n;
constexpr int GranularityK = decltype(scale_n)::GranularityK;
const auto& scale_a_tensor_view = [&]() {
auto scale_m_desc = kargs.scale_m;
if constexpr(AQUANT_Pipeline)
{
constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0
? 1
: decltype(scale_m_desc)::GranularityK;
index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK;
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
// Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load
const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl));
const auto scale_a_desc = transform_tensor_descriptor(
scale_a_naive_desc,
make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)),
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_m_desc.ptr), scale_a_desc);
}
else
{
constexpr int AGranularityK = 32;
constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0);
index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl);
return make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_m_desc.ptr),
make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl),
make_tuple(scale_k_packs * KThreadPerXdl, 1),
number<8>{},
number<1>{});
}
}();
using ScaleType = std::conditional_t<MXFP4_Pipeline, e8m0_t, float>;
const auto scale_b_flat_view = [&]() {
auto scale_n = kargs.scale_n;
constexpr int BGranularityK =
decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK;
if constexpr(AQUANT_Pipeline)
{
index_t scale_k =
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1);
constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1);
index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl);
index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl);
const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed(
make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl));
const auto scale_b_desc = transform_tensor_descriptor(
scale_b_navie_desc,
make_tuple(make_merge_transform(make_tuple(scale_n_packs, NThreadPerXdl)),
make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto scale_b_flat_view = make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
return make_tensor_view<address_space_enum::global>(
reinterpret_cast<const int32_t*>(scale_n.ptr) +
expert_id * kargs.N * scale_k / 4,
scale_b_desc);
}
else
{
index_t scale_k =
BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK;
index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1);
index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1);
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view);
using ScaleType = e8m0_t;
return make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const ScaleType*>(scale_n.ptr) + expert_id * kargs.N * scale_k,
make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK),
make_tuple(FlatScaleK, 1),
number<8>{},
number<1>{});
}
}();
return make_tuple(a_tensor_view,
b_flat_tensor_view,
c_tensor_view,
scale_a_tensor_view,
scale_b_flat_view);
}
template <typename TensorView>
@@ -718,7 +803,7 @@ struct MoeFlatmmKernel
}
}();
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3));
return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3), views.at(I4));
}
template <typename PadView>
@@ -747,7 +832,7 @@ struct MoeFlatmmKernel
}
}();
constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline;
constexpr bool isNonInterleaveGateUp = !IsGateUp || BMXFP4_Pipeline;
const auto& b_flat_block_window =
make_tile_window(b_flat_pad_view,
@@ -766,17 +851,40 @@ struct MoeFlatmmKernel
output_N_offset});
constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline
auto a_scale_block_window = make_tile_window(
views.at(I3),
make_tuple(number<TilePartitioner::MPerBlock / M_Pack>{},
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
{coord_m / M_Pack, 0});
constexpr int XDLPerLoadScaleB =
MXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
BMXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4
auto scale_block_window =
make_tile_window(views.at(I3),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
XDLPerLoadScaleB / GranularityK>{}),
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
auto b_scale_block_window = [&]() {
if constexpr(MXF8F6F4MFMA)
{
return make_tile_window(
views.at(I4),
make_tuple(number<TilePartitioner::NPerBlock / N_Pack>{},
number<TilePartitioner::KPerBlock / (GranularityK * K_Pack)>{}),
{coord_n / N_Pack, 0});
}
else
{
return make_tile_window(
views.at(I4),
make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
number<FlatmmPipeline::flatKPerWarp * N_Pack * K_Pack *
XDLPerLoadScaleB / GranularityK>{}),
{coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0});
}
}();
return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window);
return make_tuple(a_block_window,
b_flat_block_window,
c_block_window,
a_scale_block_window,
b_scale_block_window);
}
template <class MoeFlatmmKernelArgs>
@@ -831,7 +939,6 @@ struct MoeFlatmmKernel
if(coord_m >= max_token_id)
return;
static_for<0, DramMRepeat, 1>{}([&](auto m0) {
const auto row_idx =
coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0];
@@ -864,9 +971,10 @@ struct MoeFlatmmKernel
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
// Run GEMM cooperatively by whole workgroup.
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& scale_block_window = gemm_tile_windows.at(I3);
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_block_window = gemm_tile_windows.at(I1);
const auto& a_scale_block_window = gemm_tile_windows.at(I3);
const auto& b_scale_block_window = gemm_tile_windows.at(I4);
auto a_gather_block_tile =
ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(),
@@ -876,17 +984,32 @@ struct MoeFlatmmKernel
a_offsets); // K DRAM tile window for
auto c_block_tile = [&] {
if constexpr(MXFP4_Pipeline)
if constexpr(BMXFP4_Pipeline)
{
// MXFP4_Pipeline uses gate-up interleave 16 layout for weight
// BMXFP4_Pipeline uses gate-up interleave 16 layout for weight
// so don't need extra processing
return FlatmmPipeline{}(a_gather_block_tile,
b_block_window,
scale_block_window, // weight scale with granularityK = 32
num_loop,
kargs.k_padded_zeros,
smem_ptr_ping,
smem_ptr_pong);
if constexpr(AQUANT_Pipeline)
{
return FlatmmPipeline{}(
a_gather_block_tile,
b_block_window,
a_scale_block_window, // weight scale with granularityK = 32
b_scale_block_window, // weight scale with granularityK = 32
num_loop,
smem_ptr_ping,
smem_ptr_pong);
}
else
{
return FlatmmPipeline{}(
a_gather_block_tile,
b_block_window,
b_scale_block_window, // weight scale with granularityK = 32
num_loop,
kargs.k_padded_zeros,
smem_ptr_ping,
smem_ptr_pong);
}
}
else
{
@@ -964,7 +1087,7 @@ struct MoeFlatmmKernel
constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2;
statically_indexed_array<index_t, ScaleMRepeat> scale_m_offsets;
if constexpr(!MXFP4_Pipeline)
if constexpr(!BMXFP4_Pipeline)
static_for<0, MRepeat, 1>{}([&](auto mIter) {
static_for<0, kM0, 1>{}([&](auto m0) {
static_for<0, kM2, 1>{}([&](auto m2) {
@@ -1059,7 +1182,7 @@ struct MoeFlatmmKernel
number<1>{});
auto exp_bias_window = make_tile_window(
permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}),
permute_tensor_view(exp_bias_view, number<(BMXFP4_Pipeline && !IsInputGemm)>{}),
make_tuple(number<TilePartitioner::MPerBlock>{},
number < IsGateUp ? TilePartitioner::NPerBlock / 2
: TilePartitioner::NPerBlock > {}),
@@ -1101,7 +1224,7 @@ struct MoeFlatmmKernel
ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer;
ExpWeightBuffer exp_weight_buffer;
if constexpr(!MXFP4_Pipeline)
if constexpr(!BMXFP4_Pipeline)
{
scale_m_window.load(scale_m_buffer);
scale_n_buffer = load_tile(scale_n_window);
@@ -1233,7 +1356,7 @@ struct MoeFlatmmKernel
auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n);
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
if constexpr(!MXFP4_Pipeline)
if constexpr(!BMXFP4_Pipeline)
{
gate_tensor.get_thread_buffer()[idx] *=
epi_scale_m[idx] * epi_scale_n[idx];
@@ -1260,7 +1383,7 @@ struct MoeFlatmmKernel
auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n);
static_for<0, ActVectorSize, 1>{}([&](auto idx) {
if constexpr(!MXFP4_Pipeline)
if constexpr(!BMXFP4_Pipeline)
lds_tile[lds_stage].get_thread_buffer()[idx] *=
epi_scale_m[idx] * epi_scale_n[idx];
if constexpr(EnableBias)
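Hedged illustration of the "pack 2x2 e8m0 into one int32_t" trick the A/B scale views above rely on: four adjacent 8-bit e8m0 exponents are reinterpreted as a single 32-bit word so the scale fetch becomes one dword load. The exact layout (which exponent lands in which byte) is illustrative only.

#include <cstdint>
#include <cstdio>
#include <cstring>

int main()
{
    const uint8_t e8m0[4] = {0x7e, 0x7f, 0x80, 0x81}; // a 2 (M) x 2 (K) tile of scales
    uint32_t dword        = 0;
    std::memcpy(&dword, e8m0, sizeof(dword)); // the kernel does this via reinterpret_cast
    std::printf("dword-wide scale load: 0x%08x\n", static_cast<unsigned>(dword));
    return 0;
}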

View File

@@ -156,7 +156,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8
constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 128 / 4 = 32
constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4
constexpr int K0 = K_Lane; // 4
@@ -236,4 +236,513 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
}
};
struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy
{
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr index_t kDramLoadPackBytes = 128;
static constexpr int MXdlPack = 2;
static constexpr int NXdlPack = 2;
static constexpr int KXdlPack = 2;
template <typename Problem>
static inline constexpr auto wg_attr_num_access = WGAttrNumAccessEnum::Single;
// std::is_same_v<remove_cvref_t<typename Problem::ADataType>, pk_fp4_t>
// ? WGAttrNumAccessEnum::Single
// : WGAttrNumAccessEnum::Double;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher< //
ADataType,
BDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
wg_attr_num_access<Problem>>;
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< //
ADataType,
BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockFlatmmASmemBSmemCRegV1<Problem, BlockFlatmmPolicy>{};
}
template <typename Problem, typename TensorView>
CK_TILE_DEVICE static constexpr auto
MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
const auto& naive_desc = naive_view.get_tensor_descriptor();
constexpr auto ndims = remove_cvref_t<decltype(naive_desc)>::get_num_of_dimension();
static_assert(ndims == 2, "only support 2D tensor");
const auto rows = naive_desc.get_length(number<0>{});
const auto cols = naive_desc.get_length(number<1>{});
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
const index_t K0 = cols / (K1 * K2);
const auto col_lens = make_tuple(K0, number<K1>{}, number<K2>{});
constexpr index_t M1 = 4; // so that we can use imm offset to load lds
const index_t M0 = rows / M1;
const auto row_lens = make_tuple(M0, number<M1>{});
const auto desc_0 =
make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens));
const auto desc_1 = transform_tensor_descriptor(
desc_0,
make_tuple(make_pass_through_transform(M0),
make_xor_transform(make_tuple(number<M1>{}, number<K1>{})),
make_pass_through_transform(K0),
make_pass_through_transform(number<K2>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
const auto desc = transform_tensor_descriptor( //
desc_1,
make_tuple(make_merge_transform_v3_division_mod(row_lens),
make_merge_transform_v3_division_mod(col_lens)),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1));
return tensor_view<typename TensorView::buffer_view,
remove_cvref_t<decltype(desc)>,
TensorView::DstInMemOp>{naive_view.buf_, desc};
}
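// Worked example for the descriptor above (the 128x256 fp4 A tile used here is an assumption,
// not a requirement): with APackedSize = 2 and a 16-byte smem pack, K2 = 32 fp4 values,
// K1 = 128 * 2 / 32 = 8 such 16-byte groups per 128-byte DRAM pack, and K0 = 256 / (8 * 32) = 1;
// on the row side M1 = 4 and M0 = 128 / 4 = 32. The xor over (M1, K1) interleaves the four
// consecutive rows of each 128-byte pack across different 16-byte groups, which reduces LDS
// read conflicts later, and the final merge restores a plain (M, K) view for the async copy.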
template <typename Problem, typename TensorView>
CK_TILE_DEVICE static constexpr auto
Make_F8AAsyncLoadDramDescriptor(const TensorView& naive_view)
{
constexpr int DynamicTileOffsetFlag = 0;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
// implement swizzle pattern on global side
// because we can't adjust the ds_write pattern of BUFFER_LOAD_LDS.
auto swizzle_a_dram_view_1 = transform_tensor_view(
naive_view,
make_tuple(
// M-dim is not affected by swizzle pattern
make_unmerge_transform(
make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
// K-dim is the swizzle dimension
make_unmerge_transform(make_tuple(number<DynamicTileOffsetFlag>{},
number<KPerBlock / KPack>{},
number<KPack>{}))),
make_tuple(sequence<0>{}, sequence<1>{}),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}));
auto swizzle_a_dram_view_2 = transform_tensor_view(
swizzle_a_dram_view_1,
make_tuple(make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_xor_transform(make_tuple(number<MPerBlock>{},
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<DynamicTileOffsetFlag>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}),
make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}));
return transform_tensor_view(
swizzle_a_dram_view_2,
make_tuple(
make_merge_transform_v3_division_mod(
make_tuple(number<DynamicTileOffsetFlag>{}, number<MPerBlock>{})),
make_merge_transform_v3_division_mod(make_tuple(number<DynamicTileOffsetFlag>{},
number<KPerBlock / KPack>{},
number<KPack>{}))),
make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
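// The view returned above bakes the swizzle into the global coordinates: buffer_load_lds writes
// LDS with a fixed pattern, so instead of permuting on the ds_write side the block-row index is
// xor-ed with the 16-byte-group index of the K dimension before the load, and the data lands in
// LDS already permuted for ds_read. Illustration only (not used by the pipeline; the real
// permutation is defined by make_xor_transform above) of such a row swizzle within groups of 4:
CK_TILE_HOST_DEVICE static constexpr index_t SwizzleRowExample(index_t m, index_t k16b_group)
{
    constexpr index_t GroupSize = 4; // ContiguousThreadsCntInDS_READ_16B
    // only the low two bits of the row change, so rows stay within their group of 4,
    // e.g. SwizzleRowExample(5, 2) == 7
    return (m & ~(GroupSize - 1)) | ((m ^ k16b_group) & (GroupSize - 1));
}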
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = MPerBlock == 16
? GetSmemPackA<Problem>() * APackedSize / 4
: GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
constexpr index_t M2 = get_warp_size() / K1; // 8
constexpr index_t M1 = BlockSize / get_warp_size(); // 4
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!");
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding< //
sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>, // ?,4,8 1,8,32 or 2,8,16
tuple<sequence<1>, sequence<1, 2>>, // M1 M2,K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>, // M0,K0,K2
sequence<0, 0, 2>>{});
}
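// Worked example for the distribution above (block/tile sizes are assumptions): with
// BlockSize = 256, MPerBlock = 128, KPerBlock = 256 and fp8 data (APackedSize = 1,
// 16-byte smem pack), K2 = 16, K1 = 128 / 16 = 8 and K0 = 256 / 128 = 2; on the M side
// M2 = 64 / 8 = 8 lanes, M1 = 256 / 64 = 4 waves and M0 = 128 / 32 = 4 repeats. Each
// thread then issues one contiguous 16-byte load per (M0, K0) iteration, i.e. 8 loads
// to cover the whole 128x256 block.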
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
/* reduce transform layers, compared with old CK */
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
constexpr index_t K2 = GetSmemPackA<Problem>() * APackedSize; // f4=32; f8=16
constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8
constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256
static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!");
constexpr index_t M3 = 4; // so that we can use imm offset to load lds
constexpr index_t M2 = get_warp_size() / K1 / M3; // 2
constexpr index_t M1 = MPerXdl / (M2 * M3); // 2
constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16
static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!");
constexpr index_t Pad = 4 * K2; // 4 * 16
// constexpr index_t Pad = 0; // 4 * 16
// TODO: fix lds_a swizzle
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<M0>{},
number<M1>{},
number<K0>{},
number<M2>{},
number<M3>{},
number<K1>{},
number<K2>{}),
make_tuple(number<M1*(K0 * (M2 * M3 * K1 * K2) + (K0 - 1) * Pad)>{},
number<K0*(M2 * M3 * K1 * K2) + (K0 - 1) * Pad>{},
number<M2 * M3 * K1 * K2 + Pad>{},
number<M3 * K1 * K2>{},
number<K1 * K2>{},
number<K2>{},
number<1>{}),
number<K2>{},
number<1>{});
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<M0>{}, number<M1>{}, number<M2>{}, number<M3>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<K0>{}, number<K1>{}, number<K2>{}))),
make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
// return a_lds_block_desc_permuted;
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF8_ReadALdsBlockDescriptor()
{
constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0);
constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1);
static_assert(MPerXdl == 16 && NPerXdl == 16);
/* reduce transform layers, compared with old CK */
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr int ContiguousThreadsCntInDS_READ_16B = 4;
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock>{},
number<ContiguousThreadsCntInDS_READ_16B>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_pass_through_transform(number<MPerBlock>{}),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeF8_WriteALdsBlockDescriptor()
{
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetSmemPackA<Problem>();
return make_naive_tensor_descriptor(make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
make_tuple(number<KPerBlock>{}, number<1>{}),
number<KPack>{},
number<1>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16");
static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1");
constexpr int M_warps = TileShape::BlockWarps::at(number<0>{});
constexpr int N_warps = TileShape::BlockWarps::at(number<1>{});
constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16
constexpr int K_Lane = 64 / M_Lane; // 4
constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32
// constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr index_t num_access_v = 2;
constexpr int K1 = K_Thread / num_access_v; // 16
return make_static_tile_distribution(
std::conditional_t<
num_access_v == 1,
tile_distribution_encoding<
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<0, 2>>,
sequence<2>,
sequence<1>>,
tile_distribution_encoding< //
sequence<N_warps>,
tuple<sequence<M_warps, MXdlPack, M_Lane>, sequence<num_access_v, K_Lane, K1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<0, 0>, sequence<1, 2>>,
sequence<2, 2>,
sequence<0, 2>>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16");
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t K1 = WaveSize; // threads cnt in K dim
constexpr index_t KWavePerBlk = 1;
constexpr index_t K0 = KWavePerBlk;
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
constexpr index_t kKPerThread = 32;
constexpr index_t num_access_v = static_cast<index_t>(wg_attr_num_access<Problem>);
constexpr index_t K2 = kKPerThread / num_access_v;
return make_static_tile_distribution(
std::conditional_t< //
num_access_v == 1,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<K0, K1, K2>>, // 1 64 32
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 0>, sequence<1>>,
sequence<2>,
sequence<2>>,
tile_distribution_encoding< //
sequence<WaveRepeat>,
tuple<sequence<NWavePerBlk, NXdlPack>, // 4 2
sequence<num_access_v, K0, K1, K2>>, // 2 1 64 16
tuple<sequence<0, 1, 2>, sequence<2>>,
tuple<sequence<0, 0, 1>, sequence<2>>,
sequence<2, 2>,
sequence<0, 3>>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0);
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
constexpr index_t M_Lanes = TileShape::WarpTile::at(I0);
constexpr index_t K_Lanes = 64 / M_Lanes;
// Y dimension (M) decomposition
constexpr index_t Y2 = M_Lanes;
constexpr index_t Y1 = M_Warps;
constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2);
// X dimension (K) decomposition
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N_Warps>, // repeat N_warps
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<1, 0>, sequence<2, 1>>,
tuple<sequence<1, 0>, sequence<0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t WaveSize = get_warp_size();
constexpr index_t WaveNum = BlockSize / WaveSize;
constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1);
constexpr index_t M_Warps = TileShape::BlockWarps::at(I0);
constexpr index_t N_Warps = TileShape::BlockWarps::at(I1);
static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size");
constexpr index_t N_Lanes = TileShape::WarpTile::at(I1);
constexpr index_t K_Lanes = 64 / N_Lanes;
// Y dimension (N) decomposition
constexpr index_t Y2 = N_Lanes;
constexpr index_t Y1 = N_Warps;
constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2);
// X dimension (K) decomposition
constexpr index_t X0 = K_Lanes;
constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load
return make_static_tile_distribution(
tile_distribution_encoding<sequence<M_Warps>, // replicated across the M warps
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<0, 1>, sequence<2, 1>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0);
constexpr index_t M_Lane = TileShape::WarpTile::at(I0);
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
constexpr index_t MWavePerBlk = M_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N_Warp>, // replicated across the N warps
tuple<sequence<MWavePerBlk, M_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<1, 0>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution()
{
using TileShape = typename Problem::BlockGemmShape;
constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{});
constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1);
constexpr index_t N_Lane = TileShape::WarpTile::at(I1);
constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{});
constexpr index_t NWavePerBlk = N_Warp;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<M_Warp>, // replicated across the M warps
tuple<sequence<NWavePerBlk, N_Lane>, // second direction
sequence<K_Lane, 1>>, // first direction
tuple<sequence<0, 1>, sequence<2, 1>>, // which direction
tuple<sequence<0, 0>, sequence<0, 1>>, // which index
// <repeat, vec_load>
sequence<2>,
sequence<1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t APackedSize = numeric_traits<ADataType>::PackedSize;
return sizeof(ADataType) *
MakeMXFP4_ALdsBlockDescriptor<Problem>().get_element_space_size() / APackedSize;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return GetSmemSizeA<Problem>();
}
};
} // namespace ck_tile
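The scale-tile distributions above fetch E8M0 block scales (one scale per 32 elements along K, matching the
granularityK = 32 windows in the kernel) and pack a 2x2 (KXdlPack x MXdlPack) group of them into a single
int32_t load. For reference, a small host-side decode of such scales; the byte order inside the packed word
is an assumption:

#include <cmath>
#include <cstdint>

// E8M0 is an exponent-only byte with bias 127; 0xFF is reserved for NaN in the OCP MX formats.
inline float decode_e8m0(uint8_t e)
{
    return e == 0xFF ? std::nanf("") : std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Unpack a 2x2 group of E8M0 scales stored in one 32-bit word (byte order assumed
// little-endian, lowest byte first).
inline void unpack_scale_group(uint32_t packed, float out[4])
{
    for(int i = 0; i < 4; ++i)
        out[i] = decode_e8m0(static_cast<uint8_t>(packed >> (8 * i)));
}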

View File

@@ -309,25 +309,6 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
template <typename A, typename B, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<A, B>, AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< //
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
AttrNumAccess>>;
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed =

View File

@@ -0,0 +1,232 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include <unordered_map>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/library/utility/host_tensor.hpp"
namespace ck {
namespace tensor_operation {
namespace host {
template <typename ADataType,
typename BDataType,
typename CDataType,
typename AccDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
typename ComputeTypeA = AccDataType,
typename ComputeTypeB = AccDataType>
struct ReferenceMoeGemm1BlockScaleSplitK : public device::BaseOperator
{
// Argument
struct Argument : public device::BaseArgument
{
Argument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_t_k_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
: sorted_token_ids_{sorted_token_ids},
expert_ids_{expert_ids},
max_token_id_{max_token_id},
sorted_tile_size_{sorted_tile_size},
a_t_k_{a_t_k},
b_e_n_k_{b_e_n_k},
c_t_k_n_{c_t_k_n},
a_element_op_{a_element_op},
b_element_op_{b_element_op},
c_element_op_{c_element_op}
{
}
const Tensor<ck::index_t>& sorted_token_ids_;
const Tensor<ck::index_t>& expert_ids_;
const Tensor<ck::index_t>& max_token_id_;
index_t sorted_tile_size_;
const Tensor<ADataType>& a_t_k_;
const Tensor<BDataType>& b_e_n_k_;
Tensor<CDataType>& c_t_k_n_;
AElementwiseOperation a_element_op_;
BElementwiseOperation b_element_op_;
CElementwiseOperation c_element_op_;
};
// Invoker
struct Invoker : public device::BaseInvoker
{
using Argument = ReferenceMoeGemm1BlockScaleSplitK::Argument;
float Run(const Argument& arg)
{
const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2];
auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_t_k_.mDesc.GetLengths()[1];
AccDataType v_acc{0};
ComputeTypeA v_a{0};
ComputeTypeB v_b{0};
const int t = arg.sorted_token_ids_(m) & 0xffffff;
const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24;
const int e = arg.expert_ids_(m / arg.sorted_tile_size_);
const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0];
if(t < token_cnt)
{
for(int k = 0; k < K; ++k)
{
if constexpr(is_same_v<ADataType, pk_i4_t>)
{
uint8_t i4x2 = arg.a_t_k_(t, k).data;
uint8_t i4 = 0;
if(k % 2 == 1)
i4 = (i4x2 >> 0) & 0xf;
else
i4 = (i4x2 >> 4) & 0xf;
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_a = i4_to_f32_gfx9(i4);
#else
v_a = i4 - 8;
#endif
}
else
{
arg.a_element_op_(v_a, arg.a_t_k_(t, k));
}
// same for B matrix
if constexpr(is_same_v<BDataType, pk_i4_t>)
{
uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data;
uint8_t i4 = 0;
if(k % 2 == 1)
{
i4 = (i4x2 >> 0) & 0xf;
}
else
{
i4 = (i4x2 >> 4) & 0xf;
}
#if CK_USE_PK4_LAYOUT_SHUFFLE
v_b = i4_to_f32_gfx9(i4);
#else
v_b = i4 - 8;
#endif
}
else
{
arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n));
}
v_acc +=
ck::type_convert<AccDataType>(v_a) * ck::type_convert<AccDataType>(v_b);
}
CDataType v_c{0};
arg.c_element_op_(v_c, v_acc);
arg.c_t_k_n_(t, topk_id, n) = v_c;
}
};
const ck::index_t max_token_id = arg.max_token_id_(0);
make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)(
std::thread::hardware_concurrency());
return 0;
}
float Run(const device::BaseArgument* p_arg,
const StreamConfig& /* stream_config */ = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg));
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
bool IsSupportedArgument(const device::BaseArgument*) override { return true; }
static auto MakeArgument(const Tensor<ck::index_t>& sorted_token_ids,
const Tensor<ck::index_t>& expert_ids,
const Tensor<ck::index_t>& max_token_id,
const index_t sorted_tile_size,
const Tensor<ADataType>& a_t_k,
const Tensor<BDataType>& b_e_n_k,
Tensor<CDataType>& c_t_k_n,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{sorted_token_ids,
expert_ids,
max_token_id,
sorted_tile_size,
a_t_k,
b_e_n_k,
c_t_k_n,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
virtual std::unique_ptr<device::BaseInvoker> MakeInvokerPointer()
{
return std::make_unique<Invoker>(Invoker{});
}
std::string GetTypeString() const override
{
auto str = std::stringstream();
// clang-format off
str << "ReferenceMoeGemm1BlaockScaleSplitK"
<< std::endl;
// clang-format on
return str.str();
}
static float i4_to_f32_gfx9(uint8_t i4)
{
static std::unordered_map<uint8_t, float> u = {{0b1000, -0.5000f},
{0b1001, -0.4375f},
{0b1010, -0.3750f},
{0b1011, -0.3125f},
{0b1100, -0.2500f},
{0b1101, -0.1875f},
{0b1110, -0.1250f},
{0b1111, -0.0625f},
{0b0, +0.0000f},
{0b1, +0.0625f},
{0b10, +0.1250f},
{0b11, +0.1875f},
{0b100, +0.2500f},
{0b101, +0.3125f},
{0b110, +0.3750f},
{0b111, +0.4375f}};
return u[i4];
}
};
} // namespace host
} // namespace tensor_operation
} // namespace ck
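For reference, the two bit-level decodes used by the reference loop above, restated as small host helpers
(names are illustrative, not part of the library):

#include <cstdint>

// sorted_token_ids packs the token index in bits [23:0] and the top-k slot in bits [31:24].
inline void unpack_sorted_token_id(uint32_t packed, int& token, int& topk_id)
{
    token   = static_cast<int>(packed & 0xffffff);
    topk_id = static_cast<int>((packed & 0xff000000) >> 24);
}

// pk_i4 stores two int4 values per byte; as in the reference loop, odd k reads the low nibble
// and even k reads the high nibble.
inline uint8_t select_i4(uint8_t i4x2, int k)
{
    return (k % 2 == 1) ? (i4x2 & 0xf) : ((i4x2 >> 4) & 0xf);
}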