[CK] Mxfp4 moe blockscale buf2lds version support (#2455)

* change cshuffle size

* added mxfp4 moe async buffer loading without B preshuffle

* added mx moe B shuffling + scale shuffling (async loads)

* minor fix

---------

Co-authored-by: mtgu0705 <mtgu@amd.com>

[ROCm/composable_kernel commit: 7998ae8969]
This commit is contained in:
Mingtao Gu
2025-07-06 15:42:00 +08:00
committed by GitHub
parent a53dbafec9
commit 7face91352
19 changed files with 10677 additions and 3473 deletions

View File

@@ -22,16 +22,35 @@ add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns)
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns)
add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4)
add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4)
add_example_executable(example_moe_gemm1_xdl_mx_fp4_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bpreshuffle)
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bpreshuffle)
set(FP4_MXGEMM_OPTIONS)
list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1")
example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
# mx moe B no-shuffling + scale shuffling
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
# mx moe B no-shuffling + scale shuffling (async loads)
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
# mx moe B shuffling + scale shuffling (async loads)
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
set(FP8_MXGEMM_OPTIONS)
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})

View File

@@ -0,0 +1,548 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t BlockSize = 256;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, BlockSize,
MPerBlock, NPerBlock, KPerBlock,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
max_token_id.mData[0] = valid_size;
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_k_n_host_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_k_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
e_t_k_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{0.5f});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.5f});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k(token_id, k);
}
}
}
// A/B scale shuffle
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
b_scale_preshuffled.mData.data(),
N * 2 * experts,
K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop =
// FMA * tokens * N * (Gate+Up) * topk * K +
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
std::size_t(2) * tokens * N * 2 * topk * K +
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
sizeof(EDataType) * tokens * topk * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
XDataType,
B0DataType,
XDataType,
float, // CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough,
ActOP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k,
a1_t_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n,
c_t_k_n,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < valid_size; ++m)
{
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
if(t >= tokens)
{
continue;
}
for(int n = 0; n < N; ++n)
{
e_t_k_n_host_result(t, topk_id, n) =
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
}
}
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
auto status =
ck::utils::check_err(
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
if(status == 0)
{
printf("Validation Pass.\n");
}
return status;
}
return 0;
}

View File

@@ -0,0 +1,574 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using I64 = int64_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// B preshuffle
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
I64 tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K_pk; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 64, KPerBlock,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
max_token_id.mData[0] = valid_size;
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
// B preshuffle
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_k_n_host_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_k_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
e_t_k_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k(token_id, k);
}
}
}
// A/B scale shuffle
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
b_scale_preshuffled.mData.data(),
N * 2 * experts,
K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
preShuffleBuffer(b0_e_n_k.mData.data(),
b0_preshuffled.mData.data(),
N * 2 * experts,
K,
device_op.GetPreShuffleParameters());
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop =
// FMA * tokens * N * (Gate+Up) * topk * K +
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
std::size_t(2) * tokens * N * 2 * topk * K +
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
sizeof(EDataType) * tokens * topk * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
XDataType,
B0DataType,
XDataType,
float, // CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough,
ActOP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k,
a1_t_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n,
c_t_k_n,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < valid_size; ++m)
{
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
if(t >= tokens)
{
continue;
}
for(int n = 0; n < N; ++n)
{
e_t_k_n_host_result(t, topk_id, n) =
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
}
}
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
auto status =
ck::utils::check_err(
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
if(status == 0)
{
printf("Validation Pass.\n");
}
return status;
}
return 0;
}

View File

@@ -0,0 +1,542 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
for(int i = 0; i < sorted_tile_num; i++)
{
if(i < valid_tile_num)
{
eids[i] = (i * experts) / valid_tile_num;
}
else
{
eids[i] = 3;
}
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<XDataType> a1_t_k_k(
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 8:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
// d2_e_n.savetxt("weight.txt", "int");
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
}
}
}
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// FMA * tokens * N * topk * K +
// FMA * tokens * N * topk * (K/BlockScale)
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
XDataType,
B0DataType,
XDataType,
D2DataType,
float, // using float for Cshuffle type
// in reference
AccDataType,
PassThrough,
PassThrough,
CDEElementOp,
MulRoutedWeight,
float,
float>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k_k,
a1_t_k_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n, // topk weights
c_t_n,
PassThrough{},
PassThrough{},
cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
}
return 0;
}

View File

@@ -158,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on

View File

@@ -0,0 +1,584 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using I64 = int64_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// B preshuffle
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
I64 tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K_pk; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
8, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
for(int i = 0; i < sorted_tile_num; i++)
{
if(i < valid_tile_num)
{
eids[i] = (i * experts) / valid_tile_num;
}
else
{
eids[i] = 3;
}
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<XDataType> a1_t_k_k(
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
// B preshuffle
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 8:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
}
}
}
// A, B Scale preshuffle
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
preShuffleBuffer(b0_e_n_k.mData.data(),
b0_preshuffled.mData.data(),
N * experts,
K,
device_op.GetPreShuffleParameters());
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// FMA * tokens * N * topk * K +
// FMA * tokens * N * topk * (K/BlockScale)
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
XDataType,
B0DataType,
XDataType,
D2DataType,
float, // using float for Cshuffle type
// in reference
AccDataType,
PassThrough,
PassThrough,
CDEElementOp,
MulRoutedWeight,
float,
float>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k_k,
a1_t_k_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n, // topk weights
c_t_n,
PassThrough{},
PassThrough{},
cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
}
return 0;
}

View File

@@ -1,919 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
namespace ck {
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1
{
};
template <index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
BlockGemmPipelineScheduler::Intrawave,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack> : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
using Base::xdlops_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetWaveIdx;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::KThreadChunk;
using Base::APackedSize;
using Base::BPackedSize;
using Base::ComputePackedSize;
using AccType = typename Base::AccType;
using Tuple4 = typename Base::Tuple4;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
static constexpr auto ScalesPerKBlockSize =
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleGridBuffer,
typename AScaleGridDesc,
typename AScaleThreadTransfer,
typename BScaleGridBuffer,
typename BScaleGridDesc,
typename BScaleThreadTransfer>
__device__ void Run(
// ABlockCopy
const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
BBlockTransfer& b_blockwise_copy_up,
const BGridBuffer& b_grid_buf,
const BGridBuffer& b_grid_buf_up,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
// CThread
CThreadBuffer& c_thread_buf,
CThreadBuffer& c_thread_buf_up,
// A and B scales
const AScaleGridDesc& a_scale_grid_desc,
AScaleThreadTransfer& a_scale_thread_copy,
const AScaleGridBuffer& a_scale_grid_buf,
const BScaleGridDesc& b_scale_grid_desc,
BScaleThreadTransfer& b_scale_thread_copy,
BScaleThreadTransfer& b_scale_thread_copy_up,
const BScaleGridBuffer& b_scale_grid_buf,
const BScaleGridBuffer& b_scale_grid_buf_up,
index_t num_loop) const
{
ignore = b_block_desc;
ignore = b_block_buf;
ignore = a_scale_grid_buf;
ignore = b_scale_grid_buf;
ignore = b_scale_grid_buf_up;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs_up;
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales to buf 0
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(I0));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales to buf 0
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
auto b_scale_thread_buf_copy_up =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy_up);
b_scale_thread_bufs_up(I0)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy_up[Number<0>{}];
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
__builtin_amdgcn_sched_barrier(0);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Prefetch a_scales to buf 1
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(I1));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales to buf 1
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
auto b_scale_thread_buf_copy_up =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy_up);
b_scale_thread_bufs_up(I1)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy_up[Number<0>{}];
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
// Local prefetch A1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
// Initialize C
c_thread_buf.Clear();
c_thread_buf_up.Clear();
// main body
if constexpr(HasMainLoop)
{
// loop over k with the step KPerBlock
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(local_read_buf));
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(
0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec_up;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[mfma_reg_buf]
[Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[mfma_reg_buf]
[Number<b_scale_offset + s>{}];
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up[mfma_reg_buf]
[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
// a thread copy
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk *
xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
// Prefetch a_scales
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(mfma_reg_buf));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
auto b_scale_thread_buf_copy_up =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy_up.Run(b_scale_grid_desc,
b_scale_grid_buf_up,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy_up);
b_scale_thread_bufs_up(mfma_reg_buf)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy_up[Number<0>{}];
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_thread_copy_up.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
b_blockwise_copy_up.Run(b_grid_desc,
b_grid_buf_up,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs_up(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec_up;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
// a thread copy
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec_up;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up[I1][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec_up;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec_up.template AsType<mfma_input_type_b>(),
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from a_scale_grid to a_scale_thread
static constexpr auto a_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from b_scale_grid to b_scale_thread_buf
static constexpr auto b_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
protected:
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
// using Base::b_thread_desc_;
using Base::c_thread_desc_;
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
};
} // namespace ck

View File

@@ -3,8 +3,6 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"
@@ -43,54 +41,11 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
{
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
;
return nullptr;
}
else
{
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
return nullptr;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)

View File

@@ -1,813 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
namespace ck {
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
{
};
template <index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineScheduler::Intrawave,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
using Base::xdlops_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetWaveIdx;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride;
using Base::BMmaKStride;
using Base::KThreadChunk;
using Base::APackedSize;
using Base::BPackedSize;
using Base::ComputePackedSize;
using AccType = typename Base::AccType;
using Tuple4 = typename Base::Tuple4;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
template <typename TileDesc_M0_M1_M2_K>
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
{
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
constexpr index_t K2 = KPack;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat;
return transform_tensor_descriptor(
TileDesc_M0_M1_M2_K{},
make_tuple(
make_pass_through_transform(Number<M0>{}),
make_pass_through_transform(Number<M1>{}),
make_pass_through_transform(Number<M2>{}),
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
}
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
static constexpr auto ScalesPerKBlockSize =
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleGridBuffer,
typename AScaleGridDesc,
typename AScaleThreadTransfer,
typename BScaleGridBuffer,
typename BScaleGridDesc,
typename BScaleThreadTransfer>
__device__ void Run(
// ABlockCopy
const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
// CThread
CThreadBuffer& c_thread_buf,
// A and B scales
const AScaleGridDesc& a_scale_grid_desc,
AScaleThreadTransfer& a_scale_thread_copy,
const AScaleGridBuffer& a_scale_grid_buf,
const BScaleGridDesc& b_scale_grid_desc,
BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf,
index_t num_loop) const
{
ignore = b_block_desc;
ignore = b_block_buf;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_buf(I0)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
// Prefetch b_scales to buf 0
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
__builtin_amdgcn_sched_barrier(0);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Prefetch a_scales to buf 1
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
a_scale_thread_buf(I1)(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
// Prefetch b_scales to buf 1
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
// Local prefetch A1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
// Initialize C
c_thread_buf.Clear();
// main body
if constexpr(HasMainLoop)
{
// loop over k with the step KPerBlock
index_t i = 0;
do
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(
0 < ScalesPerXdlopsRunPerThread,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[mfma_reg_buf]
[Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[mfma_reg_buf]
[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops /
APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops /
BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
// a thread copy
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk *
xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
// Prefetch a_scales
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(I0, I0, I0),
a_scale_thread_bufs(mfma_reg_buf));
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_n0_n1_k0_k1,
b_block_origin_idx,
b_thread_bufs(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
block_sync_lds();
// a thread copy
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
else if constexpr(TailNum == TailNumber::Odd)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
}
}
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from a_scale_grid to a_scale_thread
static constexpr auto a_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from b_scale_grid to b_scale_thread_buf
static constexpr auto b_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
protected:
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
// using Base::b_thread_desc_;
using Base::c_thread_desc_;
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
};
} // namespace ck

View File

@@ -0,0 +1,109 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp"
namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
// must be used for compute
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool GUFusion = false>
constexpr auto BlockGemmMXPipeline_Selector()
{
// Hardware MX GEMM pipeline
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if constexpr(GUFusion)
{
return nullptr;
}
else
{
return nullptr;
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if constexpr(GUFusion)
{
return BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3<
BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
}
else
{
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
}
}
} // namespace ck

View File

@@ -0,0 +1,405 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
namespace ck {
/**
* Transfer that uses direct load instructions to copy data from global to LDS memory.
*
* Traditional loads first copy data from global to registers, and then from registers to LDS.
* Direct loads do not need an intermediate step, data is copied directly from global to LDS,
* without the use of additional registers.
*
* However, the instruction has limitations:
* - each thread must copy exactly a single DWORD - 4 bytes;
* - threads within a single wavefront must write consecutive DWORDS into LDS,
* (data in global do not need to be contiguous, each thread might have its own offset).
*
* To make sure that all the transfers finished, the `waitcnt` instruction must be used with
* `vmcnt` instead of `lgkmcnt`.
*
* Limitations of the transfer class:
* - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight;
* - `DstVectorDim` must be the last dimension;
* - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1;
* - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B
* (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16,
* `ScalarPerVector` must be 2);
* - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be
* the same dimension;
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
* to guarantee that.
*/
template <typename ThreadGroup,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t ScalarPerVector,
typename IndexType,
index_t GatherDim = 1>
struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
static constexpr auto thread_single_load_size = generate_sequence(
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
static __device__ constexpr bool AreThreadClusterLengthsValid()
{
// Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to
// LDS by the threads from a single wavefront.
// Examples (assuming 64 threads in a wavefront, 128 in a thread block):
// 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
// data type = fp32 -> ScalarPerVector = 1
// INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31
// write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of
// [0, 4, 0].
// VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration,
// threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs).
// 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
// data type = fp16 -> ScalarPerVector = 2
// NOTE: ThreadClusterLengths must take into account that each thread writes two
// elements (single DWORD) along the contiguous dimension.
// INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write
// 8 * 2 elements of K1PerBlock and there are only 8;
// ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31
// write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32
// writes [1, 0, 0] instead of [0, 8, 0].
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
// elements = 64 consecutive DWORDs.
#if defined(__gfx950__)
int num_contiguous_dwords = 4;
#else
int num_contiguous_dwords = 1;
#endif
bool is_contiguous = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(is_contiguous)
{
num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1];
}
if(thread_slice_lengths[nDim - i - 1] > 1)
{
is_contiguous = false;
}
});
constexpr index_t wavefront_size = get_warp_size();
const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0;
bool thread_slice_lengths_correct = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(thread_slice_lengths[i] <= 0)
{
thread_slice_lengths_correct = false;
}
});
return wave_contiguous && thread_slice_lengths_correct;
}
__device__ constexpr ThreadGroupTensorSliceTransfer_Gather_DirectLoad(
const SrcDesc& src_desc,
const Index& src_block_slice_origin,
const DstDesc& dst_desc,
const Index& dst_block_slice_origin,
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
: gather_offsets_(gather_offsets)
{
static_assert(ck::is_same_v<SrcData, DstData>,
"Direct load transfer does not support datatypes conversion. Source and "
"destination data types must be the same.");
static_assert(
DstVectorDim == nDim - 1,
"Direct load transfer requires the destination vector dimension to be the last one.");
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");
// constexpr auto dword_bytes = 4;
// constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
// static_assert(bytes_per_thread_load == dword_bytes,
// "Direct load transfer requires each thread to load exactly a single "
// "DWORD of data.");
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
nDim == ThreadClusterLengths::Size(),
"Inconsistent number of dimensions across lengths and descriptors.");
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
"The number of threads cannot be less than the number of elements in "
"thread cluster lengths.");
// static_assert(
// AreThreadClusterLengthsValid(),
// "Thread cluster lengths are incorrect. They must be set in a way that allows a single
// " "wavefront to write contiguous DWORDs into LDS memory. ");
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
constexpr auto wave_cluster_lengths = generate_sequence_v2(
[&](auto i) {
if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
{
return Number<ThreadGroup::GetNumOfThread() / 64>{};
}
else
{
return I1;
}
},
Number<nDim>{});
constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() / 64));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
// We don't need threadwise offset for lds since it was calculate by HW
// We still need input the wavewise offset.
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
auto adjusted_src_origin_idx = [&]() {
Index idx;
static_for<0, nDim, 1>{}([&](auto i) {
idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
});
return idx;
}();
// CK_PRINT<decltype(adjusted_src_origin_idx)>();
// CK_PRINT<decltype(src_slice_origin_idx)>();
src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx);
src_slice_origin_ = adjusted_src_origin_idx;
}
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
dst_slice_origin_ = dst_slice_origin_idx;
}
__device__ void ResetDstSliceWindow(const DstDesc& dst_desc)
{
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_);
}
template <typename SrcBuffer, typename DstBuffer>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc& dst_desc,
DstBuffer& dst_buf)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
"Source data must come from a global memory buffer.");
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
"Destination data must be stored in an LDS memory buffer.");
static_assert(
ck::is_same_v<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>,
"SrcBuffer and SrcData data types must be consistent.");
static_assert(
ck::is_same_v<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>,
"DstBuffer and DstData data types must be consistent.");
constexpr auto dst_access_lengths = thread_slice_lengths;
const auto dst_forward_steps = generate_steps(dst_desc, 1);
const auto dst_backward_steps = generate_steps(dst_desc, -1);
const auto src_forward_steps = generate_steps(src_desc, 1);
const auto src_backward_steps = generate_steps(src_desc, -1);
// Loop over the destination block and copy data.
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
IndexType gather_offset = gather_offsets_[ordered_dst_access_idx[Number<GatherDim>{}]];
// src_coord_xor_ = src_coord_;
// src_coord_xor_.GetIndex().At(I0) =
// src_coord_.GetIndex().At(I0) ^ ((threadIdx.x % 64) / 8);
Index new_index = src_coord_.GetIndex();
new_index(I0) = src_coord_.GetIndex().At(I0) ^ ((threadIdx.x % 64) / 8);
src_coord_xor_ = make_tensor_coordinate(src_desc, new_index);
const IndexType src_offset = src_coord_xor_.GetOffset() + gather_offset;
const IndexType dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
// Check if src data is not in the logic padding area.
// Leave the HW for oob checking
// const bool is_src_valid =
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
// src_coord_);
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
dst_buf, src_offset, dst_offset, true);
constexpr auto move_src_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
});
move_on_dim_(i) &= i.value != GatherDim;
});
return move_on_dim_;
}
();
constexpr auto move_dst_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim_;
static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
});
});
return move_on_dim_;
}
();
// Decide whether to move forward or backward.
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep_;
forward_sweep_(I0) = true;
static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
static_for<1, i, 1>{}([&](auto j) {
tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
});
forward_sweep_(i) = tmp % 2 == 0;
});
return forward_sweep_;
}();
static_for<0, nDim, 1>{}([&](auto i) {
// Move the source coordinate.
if constexpr(move_src_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
}
else
{
move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
}
}
// Move the destination coordinate.
if constexpr(move_dst_on_dim[i])
{
if constexpr(forward_sweep[i])
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
}
else
{
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
}
}
});
});
// Reset the destination slice since the entire buffer has been already filled.
ResetDstSliceWindow(dst_desc);
}
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
{
src_slice_origin_ = src_slice_origin_ + step;
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_);
}
template <typename DescType>
__device__ auto generate_steps(const DescType& desc, int sign)
{
return generate_tuple(
[&](auto i) {
Index step_idx;
static_for<0, nDim, 1>{}([&](auto j) {
step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0;
});
return make_tensor_coordinate_step(desc, step_idx);
},
Number<nDim>{});
}
private:
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
SrcCoord src_coord_xor_;
DstCoord dst_coord_;
Index src_slice_origin_;
Index dst_slice_origin_;
StaticallyIndexedArray<IndexType, gather_num> gather_offsets_;
// static constexpr auto a_grid_xor_desc = make_naive_tensor_descriptor_packed(
// make_tuple(Number<AK0 ^ ((threadIdx / AK0) % AK0)>{}, Number<M>{}, Number<AK1>{}));
};
} // namespace ck

View File

@@ -194,10 +194,10 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / BPackedSize;
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
@@ -245,49 +245,31 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
}
};
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
APackedSize / BlockSize / 4 *
(1 + GridwiseGemm::NWave);
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) /
BPackedSize / BlockSize / 4 * (2) *
(IsInputGemm ? 2 : 1);
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
BlockSize / 4 * (IsInputGemm ? 2 : 1);
constexpr auto estimated_reg_total =
estimated_reg_a + estimated_reg_b + estimated_reg_c;
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
// TODO: Check if this is the right algorithm for minimum_occupancy
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
? 2
: 1
: 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Full>;
RunKernel(kernel);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
@@ -315,26 +297,15 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
}
else
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Full>;
RunKernel(kernel);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{

View File

@@ -0,0 +1,567 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t ScaleBlockSize,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t ActivationOP = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = BDataType>
struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockSize,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseMoeGemmMX_BPreshuffle<
ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ActivationOP,
NSwizzle,
IsInputGemm,
MulRoutedWeight,
IndexType,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
int GetPreShuffleParameters() override { return NPerXDL; }
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
std::array<std::size_t, NumDTensor> DsSize;
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
// TODO: Check if this is the right algorithm for minimum_occupancy
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
? 2
: 1
: 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else
{
throw std::runtime_error("todo: only v1 & v3 support now");
}
}
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
// only impl kbatch 1 now
if(arg.KBatch > 1)
{
return false;
}
if(!ck::is_xdl_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_sorted_token_ids,
const void* p_sorted_expert_ids,
const void* p_max_token_id,
const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t NumTokens,
index_t TopK,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideScaleA,
index_t StrideB,
index_t StrideScaleB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
static_cast<const index_t*>(p_sorted_expert_ids),
static_cast<const index_t*>(p_max_token_id),
static_cast<const ADataType*>(p_a),
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
p_ds,
static_cast<CDataType*>(p_c),
NumTokens,
TopK,
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideScaleA,
index_t StrideB,
index_t StrideScaleB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(nullptr,
nullptr,
nullptr,
static_cast<const ADataType*>(p_a),
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
p_ds,
static_cast<CDataType*>(p_c),
M, // randoms set, no use
0,
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceMoeGEmmMx"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff