mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK] Mxfp4 moe blockscale buf2lds version support (#2455)
* change cshuffle size
* added mxfp4 moe async buffer loading without B preshuffle
* added mx moe B shuffling + scale shuffling (async loads)
* minor fix
---------
Co-authored-by: mtgu0705 <mtgu@amd.com>
[ROCm/composable_kernel commit: 7998ae8969]
This commit is contained in:
@@ -22,16 +22,35 @@ add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns)
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns)
|
||||
|
||||
add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4)
|
||||
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4)
|
||||
|
||||
add_example_executable(example_moe_gemm1_xdl_mx_fp4_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bpreshuffle)
|
||||
|
||||
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp)
|
||||
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bpreshuffle)
|
||||
|
||||
set(FP4_MXGEMM_OPTIONS)
|
||||
list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1")
|
||||
example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
# mx moe B no-shuffling + scale shuffling
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
# mx moe B no-shuffling + scale shuffling (async loads)
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
# mx moe B shuffling + scale shuffling (async loads)
|
||||
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
|
||||
|
||||
set(FP8_MXGEMM_OPTIONS)
|
||||
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
|
||||
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
|
||||
|
||||
548
example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp
Normal file
548
example/67_gemm_microscaling/moe_gemm1_xdl_mx_fp4.cpp
Normal file
@@ -0,0 +1,548 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr ck::index_t NPerBlock = 64;
|
||||
static constexpr ck::index_t BlockSize = 256;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, BlockSize,
|
||||
MPerBlock, NPerBlock, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
|
||||
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
|
||||
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_k_n_host_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_k_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
|
||||
e_t_k_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{0.5f});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.5f});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k(token_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A/B scale shuffle
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
|
||||
b_scale_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop =
|
||||
// FMA * tokens * N * (Gate+Up) * topk * K +
|
||||
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
|
||||
std::size_t(2) * tokens * N * 2 * topk * K +
|
||||
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
|
||||
sizeof(B0DataType) / 2 * K * N * 2 * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
|
||||
sizeof(EDataType) * tokens * topk * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
float, // CShuffleDataType,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ActOP,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
a1_t_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < valid_size; ++m)
|
||||
{
|
||||
const int fuse_t = sorted_token_ids.mData[m];
|
||||
const int t = fuse_t & 0xffffff;
|
||||
const int topk_id = (fuse_t & 0xff000000) >> 24;
|
||||
|
||||
if(t >= tokens)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_k_n_host_result(t, topk_id, n) =
|
||||
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -0,0 +1,574 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
using I64 = int64_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// B preshuffle
|
||||
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
I64 tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K_pk; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
static constexpr ck::index_t Nswizzle = false;
|
||||
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 64, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
|
||||
}
|
||||
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
|
||||
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
|
||||
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
// B preshuffle
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
|
||||
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_k_n_host_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
Tensor<EDataType> e_t_k_n_device_result(
|
||||
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
|
||||
|
||||
e_t_k_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k(token_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A/B scale shuffle
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
|
||||
b_scale_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
preShuffleBuffer(b0_e_n_k.mData.data(),
|
||||
b0_preshuffled.mData.data(),
|
||||
N * 2 * experts,
|
||||
K,
|
||||
device_op.GetPreShuffleParameters());
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
std::size_t flop =
|
||||
// FMA * tokens * N * (Gate+Up) * topk * K +
|
||||
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
|
||||
std::size_t(2) * tokens * N * 2 * topk * K +
|
||||
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
|
||||
sizeof(B0DataType) / 2 * K * N * 2 * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
|
||||
sizeof(EDataType) * tokens * topk * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
float, // CShuffleDataType,
|
||||
D2DataType,
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
ActOP,
|
||||
MulRoutedWeight>;
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k,
|
||||
a1_t_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n,
|
||||
c_t_k_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{});
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int m = 0; m < valid_size; ++m)
|
||||
{
|
||||
const int fuse_t = sorted_token_ids.mData[m];
|
||||
const int t = fuse_t & 0xffffff;
|
||||
const int topk_id = (fuse_t & 0xff000000) >> 24;
|
||||
|
||||
if(t >= tokens)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_k_n_host_result(t, topk_id, n) =
|
||||
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
|
||||
|
||||
auto status =
|
||||
ck::utils::check_err(
|
||||
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
|
||||
? 0
|
||||
: 1;
|
||||
if(status == 0)
|
||||
{
|
||||
printf("Validation Pass.\n");
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
542
example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp
Normal file
542
example/67_gemm_microscaling/moe_gemm2_xdl_mx_fp4.cpp
Normal file
@@ -0,0 +1,542 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = sorted_tile_num;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
int eids[sorted_tile_num]{};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
if(i < valid_tile_num)
|
||||
{
|
||||
eids[i] = (i * experts) / valid_tile_num;
|
||||
}
|
||||
else
|
||||
{
|
||||
eids[i] = 3;
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<XDataType> a1_t_k_k(
|
||||
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
|
||||
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 8:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
|
||||
// d2_e_n.savetxt("weight.txt", "int");
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
|
||||
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
|
||||
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
// FMA * tokens * N * topk * K +
|
||||
// FMA * tokens * N * topk * (K/BlockScale)
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
|
||||
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
D2DataType,
|
||||
float, // using float for Cshuffle type
|
||||
// in reference
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp,
|
||||
MulRoutedWeight,
|
||||
float,
|
||||
float>;
|
||||
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k_k,
|
||||
a1_t_k_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n, // topk weights
|
||||
c_t_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -158,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
|
||||
4, 4,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
|
||||
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
|
||||
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,584 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using F4 = ck::f4x2_pk_t;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
|
||||
using I64 = int64_t;
|
||||
|
||||
using Row = ck::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using A0DataType = F4;
|
||||
using A1DataType = XPackedDataType;
|
||||
using B0DataType = F4;
|
||||
using B1DataType = XPackedDataType;
|
||||
using EDataType = F16;
|
||||
using AccDataType = F32;
|
||||
using CShuffleDataType = F16;
|
||||
using D0DataType = F32;
|
||||
using D1DataType = F32;
|
||||
using D2DataType = F32;
|
||||
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
|
||||
|
||||
using A0Layout = Row;
|
||||
using B0Layout = Col;
|
||||
using ELayout = Row;
|
||||
using D0Layout = Row;
|
||||
using D1Layout = Col;
|
||||
using D2Layout = ELayout;
|
||||
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
|
||||
|
||||
// d0: ascale, d1: bscale, d2:expert weight
|
||||
struct MulABScaleExpertWeight
|
||||
{
|
||||
template <typename E, typename C, typename D0, typename D1, typename D2>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
|
||||
// for real kernel use
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
|
||||
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
(void)d0;
|
||||
(void)d1;
|
||||
(void)d2;
|
||||
|
||||
e = ck::type_convert<EDataType>(c);
|
||||
}
|
||||
// for reference cpu
|
||||
template <>
|
||||
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
|
||||
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
|
||||
{
|
||||
// for reference cpu
|
||||
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
|
||||
}
|
||||
};
|
||||
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
// B preshuffle
|
||||
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
|
||||
{
|
||||
int KPack = 16;
|
||||
int NLane = NXdl;
|
||||
int KLane = 64 / NLane;
|
||||
int K_pk = K / 2;
|
||||
int K0 = K_pk / (KLane * KPack);
|
||||
// K -> K0 KLane KPack
|
||||
// N -> N0 NLane
|
||||
// N, K -> N0 K0 KLane NLane KPack
|
||||
I64 tempk;
|
||||
for(I64 n = 0; n < N; ++n)
|
||||
{
|
||||
for(I64 k = 0; k < K_pk; ++k)
|
||||
{
|
||||
I64 n0 = n / NLane;
|
||||
I64 n1 = n % NLane;
|
||||
|
||||
I64 k0 = k / (KLane * KPack);
|
||||
tempk = k % (KLane * KPack);
|
||||
I64 k1 = tempk / KPack;
|
||||
I64 k2 = tempk % KPack;
|
||||
|
||||
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
|
||||
k1 * KPack * NLane + n1 * KPack + k2;
|
||||
|
||||
dst[outputIndex] = src[n * K_pk + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A, B Scale preshuffle
|
||||
template <bool KLast>
|
||||
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
|
||||
{
|
||||
int MNXdlPack = 2;
|
||||
int KXdlPack = 2;
|
||||
|
||||
int XdlMNThread = 16;
|
||||
int XdlKThread = 64 / XdlMNThread;
|
||||
|
||||
int K0 = K / KXdlPack / XdlKThread; // KRepeat
|
||||
|
||||
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
|
||||
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
|
||||
|
||||
// unfold the MN32xK(256/32) scale buffer
|
||||
// 4 16 2 2
|
||||
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
|
||||
// Then, MNRepeat->KRepeat
|
||||
|
||||
for(int n = 0; n < MN; ++n)
|
||||
{
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
|
||||
int tempn = n % (XdlMNThread * MNXdlPack);
|
||||
int n1 = tempn % XdlMNThread; // i XdlMNThread
|
||||
int n2 = tempn / XdlMNThread; // i MNXdlPack
|
||||
|
||||
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
|
||||
int tempk = k % (XdlKThread * KXdlPack);
|
||||
int k1 = tempk % XdlKThread; // i XdlKThread
|
||||
int k2 = tempk / XdlKThread; // i KXdlPack
|
||||
|
||||
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
|
||||
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
|
||||
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
|
||||
k2 * MNXdlPack + n2;
|
||||
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
|
||||
// k2 * MNXdlPack)));
|
||||
if constexpr(KLast)
|
||||
dst[outputIndex] = src[n * K + k];
|
||||
else
|
||||
dst[outputIndex] = src[k * MN + n];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = MulABScaleExpertWeight;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
|
||||
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
|
||||
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
|
||||
|
||||
static constexpr ck::index_t MPerBlock = 128;
|
||||
static constexpr bool MulRoutedWeight = true;
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<
|
||||
A0Layout, B0Layout, DsLayout, ELayout,
|
||||
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
|
||||
AElementOp, BElementOp, CDEElementOp, GemmSpec,
|
||||
ScaleBlockSize, 256,
|
||||
MPerBlock, 128, KPerBlock,
|
||||
16, 16,
|
||||
16, 16,
|
||||
8, 2,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
|
||||
2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
|
||||
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
|
||||
// clang-format on
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
bool do_verification = true;
|
||||
int init_method = 1;
|
||||
bool time_kernel = true;
|
||||
|
||||
// per expert:
|
||||
// GEMM shape
|
||||
constexpr ck::index_t sorted_tile_num = 13;
|
||||
constexpr ck::index_t valid_tile_num = 13;
|
||||
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
|
||||
ck::index_t valid_size = valid_tile_num * MPerBlock;
|
||||
|
||||
ck::index_t N = 6144;
|
||||
ck::index_t K = 4096;
|
||||
ck::index_t experts = 8;
|
||||
ck::index_t tokens = 832;
|
||||
ck::index_t topk = 2;
|
||||
|
||||
if(argc == 1)
|
||||
{
|
||||
// use default case
|
||||
}
|
||||
else if(argc == 4)
|
||||
{
|
||||
// use default case
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
}
|
||||
else if(argc == 7)
|
||||
{
|
||||
do_verification = std::stoi(argv[1]);
|
||||
init_method = std::stoi(argv[2]);
|
||||
time_kernel = std::stoi(argv[3]);
|
||||
N = std::stoi(argv[4]);
|
||||
K = std::stoi(argv[5]);
|
||||
tokens = std::stoi(argv[6]);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("arg1: verification (0=no, 1=yes)\n");
|
||||
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
|
||||
printf("arg3: time kernel (0=no, 1=yes)\n");
|
||||
printf("arg4 to 6: N, K, tokens\n");
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
ck::index_t StrideA = K;
|
||||
ck::index_t StrideB = K;
|
||||
ck::index_t StrideE = N;
|
||||
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
|
||||
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
|
||||
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
|
||||
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
|
||||
max_token_id.mData[0] = valid_size;
|
||||
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
|
||||
int eids[sorted_tile_num]{};
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
if(i < valid_tile_num)
|
||||
{
|
||||
eids[i] = (i * experts) / valid_tile_num;
|
||||
}
|
||||
else
|
||||
{
|
||||
eids[i] = 3;
|
||||
}
|
||||
}
|
||||
|
||||
for(int i = 0; i < sorted_tile_num; i++)
|
||||
{
|
||||
expert_ids.mData[i] = eids[i];
|
||||
}
|
||||
if(tokens * topk > valid_size)
|
||||
{
|
||||
printf("err config, tokens * topk > valid_size\n");
|
||||
exit(-1);
|
||||
}
|
||||
int token_per_tile = tokens * topk / valid_tile_num;
|
||||
int tokenid = 0;
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int tile_off = i % MPerBlock;
|
||||
if(tile_off < token_per_tile)
|
||||
{
|
||||
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
|
||||
tokenid++;
|
||||
}
|
||||
else
|
||||
{
|
||||
sorted_token_ids.mData[i] = tokens;
|
||||
}
|
||||
}
|
||||
|
||||
expert_ids.savetxt("expert_ids.txt", "int");
|
||||
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
|
||||
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
|
||||
Tensor<XDataType> a1_t_k_k(
|
||||
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
|
||||
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
|
||||
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
Tensor<XDataType> b1_e_n_k(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
|
||||
// B preshuffle
|
||||
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
|
||||
|
||||
// A, B Scale preshuffle
|
||||
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
|
||||
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
|
||||
Tensor<XDataType> b_scale_preshuffled(
|
||||
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
|
||||
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
|
||||
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
|
||||
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
|
||||
|
||||
e_t_n_device_result.SetZero();
|
||||
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
|
||||
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
|
||||
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
|
||||
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
|
||||
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
|
||||
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
|
||||
|
||||
switch(init_method)
|
||||
{
|
||||
case 0: break;
|
||||
case 1:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
|
||||
break;
|
||||
case 2:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 3:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 4:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 5:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 6:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 7:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
case 8:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
|
||||
break;
|
||||
default:
|
||||
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
|
||||
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
|
||||
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
|
||||
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
|
||||
}
|
||||
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
|
||||
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
|
||||
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
|
||||
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
|
||||
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
|
||||
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
|
||||
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
|
||||
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
|
||||
|
||||
// A scale sorted
|
||||
for(int i = 0; i < sorted_size; i++)
|
||||
{
|
||||
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
|
||||
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
|
||||
|
||||
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
|
||||
{
|
||||
if(token_id == tokens)
|
||||
{
|
||||
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A, B Scale preshuffle
|
||||
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
|
||||
a_scale_preshuffled.mData.data(),
|
||||
sorted_size,
|
||||
K / ScaleBlockSize);
|
||||
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
|
||||
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
|
||||
|
||||
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
|
||||
expert_ids_dev.ToDevice(expert_ids.mData.data());
|
||||
max_token_id_dev.ToDevice(max_token_id.mData.data());
|
||||
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
|
||||
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
|
||||
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
|
||||
d2_device_buf.ToDevice(d2_e_n.mData.data());
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
|
||||
// do GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
|
||||
preShuffleBuffer(b0_e_n_k.mData.data(),
|
||||
b0_preshuffled.mData.data(),
|
||||
N * experts,
|
||||
K,
|
||||
device_op.GetPreShuffleParameters());
|
||||
|
||||
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
|
||||
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(
|
||||
sorted_token_ids_dev.GetDeviceBuffer(),
|
||||
expert_ids_dev.GetDeviceBuffer(),
|
||||
max_token_id_dev.GetDeviceBuffer(),
|
||||
a0_device_buf.GetDeviceBuffer(),
|
||||
a1_device_buf.GetDeviceBuffer(),
|
||||
b0_device_buf.GetDeviceBuffer(),
|
||||
b1_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
|
||||
e_device_buf.GetDeviceBuffer(),
|
||||
tokens,
|
||||
topk,
|
||||
sorted_size,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideDs,
|
||||
StrideE,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! device_gemm with the specified compilation parameters does "
|
||||
"not support this GEMM problem");
|
||||
}
|
||||
|
||||
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
|
||||
{
|
||||
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
|
||||
}
|
||||
|
||||
if(time_kernel)
|
||||
{
|
||||
// not result correct here because output buf not setzero
|
||||
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
|
||||
|
||||
// FMA * tokens * N * topk * K +
|
||||
// FMA * tokens * N * topk * (K/BlockScale)
|
||||
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
|
||||
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
|
||||
|
||||
std::size_t num_btype =
|
||||
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
|
||||
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
|
||||
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
// gemm2 use atomic, so need to reinit outputs
|
||||
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
|
||||
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
|
||||
|
||||
Tensor<float> c_t_n({tokens, N});
|
||||
|
||||
using ReferenceGemmInstance =
|
||||
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
|
||||
XDataType,
|
||||
B0DataType,
|
||||
XDataType,
|
||||
D2DataType,
|
||||
float, // using float for Cshuffle type
|
||||
// in reference
|
||||
AccDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
CDEElementOp,
|
||||
MulRoutedWeight,
|
||||
float,
|
||||
float>;
|
||||
|
||||
auto ref_moe_gemm = ReferenceGemmInstance{};
|
||||
auto ref_invoker = ref_moe_gemm.MakeInvoker();
|
||||
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
|
||||
expert_ids,
|
||||
max_token_id,
|
||||
MPerBlock,
|
||||
a0_t_k_k,
|
||||
a1_t_k_k,
|
||||
b0_e_n_k,
|
||||
b1_e_n_k,
|
||||
d2_e_n, // topk weights
|
||||
c_t_n,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
cde_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
for(int t = 0; t < tokens; ++t)
|
||||
{
|
||||
for(int n = 0; n < N; ++n)
|
||||
{
|
||||
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
|
||||
}
|
||||
}
|
||||
|
||||
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
|
||||
|
||||
return ck::utils::check_err(
|
||||
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
|
||||
? 0
|
||||
: 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
@@ -1,919 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
|
||||
BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack> : BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Base::APackedSize;
|
||||
using Base::BPackedSize;
|
||||
using Base::ComputePackedSize;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple4 = typename Base::Tuple4;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
BBlockTransfer& b_blockwise_copy_up,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
const BGridBuffer& b_grid_buf_up,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
CThreadBuffer& c_thread_buf_up,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
BScaleThreadTransfer& b_scale_thread_copy_up,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
const BScaleGridBuffer& b_scale_grid_buf_up,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_desc;
|
||||
ignore = b_block_buf;
|
||||
ignore = a_scale_grid_buf;
|
||||
ignore = b_scale_grid_buf;
|
||||
ignore = b_scale_grid_buf_up;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs_up;
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 0
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales to buf 0
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 1
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(I1));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales to buf 1
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
c_thread_buf_up.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(local_read_buf));
|
||||
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_assert(
|
||||
0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[mfma_reg_buf]
|
||||
[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk *
|
||||
xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prefetch a_scales
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(mfma_reg_buf));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
|
||||
auto b_scale_thread_buf_copy_up =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy_up.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf_up,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy_up);
|
||||
|
||||
b_scale_thread_bufs_up(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy_up[Number<0>{}];
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
b_scale_thread_copy_up.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
b_blockwise_copy_up.Run(b_grid_desc,
|
||||
b_grid_buf_up,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs_up(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I1][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec_up;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
b_thread_vec_up.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec_up;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs_up[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec_up.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec_up.template AsType<BScaleDataType>(),
|
||||
c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from a_scale_grid to a_scale_thread
|
||||
static constexpr auto a_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from b_scale_grid to b_scale_thread_buf
|
||||
static constexpr auto b_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
protected:
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
// using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,8 +3,6 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_mx_moe_gufusion_v3.hpp"
|
||||
|
||||
@@ -43,54 +41,11 @@ constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
|
||||
{
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_gufusion_v1<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
;
|
||||
return nullptr;
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
|
||||
@@ -1,813 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_bpreshuffle_mx_moe_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
|
||||
using Base = BlockwiseGemmXdlops_mx_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
using Base::KThreadChunk;
|
||||
|
||||
using Base::APackedSize;
|
||||
using Base::BPackedSize;
|
||||
using Base::ComputePackedSize;
|
||||
|
||||
using AccType = typename Base::AccType;
|
||||
using Tuple4 = typename Base::Tuple4;
|
||||
using ComputeTypeA = typename Base::ComputeTypeA;
|
||||
using ComputeTypeB = typename Base::ComputeTypeB;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
|
||||
template <typename TileDesc_M0_M1_M2_K>
|
||||
__host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
|
||||
{
|
||||
constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
|
||||
constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
|
||||
constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
|
||||
constexpr index_t K2 = KPack;
|
||||
constexpr index_t K1 = 64 / NPerXDL;
|
||||
constexpr index_t K0 = KRepeat;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
TileDesc_M0_M1_M2_K{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<M0>{}),
|
||||
make_pass_through_transform(Number<M1>{}),
|
||||
make_pass_through_transform(Number<M2>{}),
|
||||
make_unmerge_transform(make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
|
||||
}
|
||||
|
||||
static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
|
||||
MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
|
||||
|
||||
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
|
||||
|
||||
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
|
||||
static constexpr auto ScalesPerXdlopsRunPerThread =
|
||||
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
ignore = b_block_desc;
|
||||
ignore = b_block_buf;
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_buf(I0)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales to buf 0
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I0)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore col id and advance to the next set of scales
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Prefetch a_scales to buf 1
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
|
||||
auto a_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_copy);
|
||||
|
||||
a_scale_thread_buf(I1)(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_copy[Number<0>{}];
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales to buf 1
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(I1)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step = k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, mfma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
static_assert(
|
||||
0 < ScalesPerXdlopsRunPerThread,
|
||||
"Must have at least one scale per Xdlops per Thread.");
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
|
||||
b_scale_thread_vec;
|
||||
|
||||
// Pack scale_thread_buf into scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[mfma_reg_buf]
|
||||
[Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops /
|
||||
BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}(
|
||||
[&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk *
|
||||
xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
// Prefetch a_scales
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(I0, I0, I0),
|
||||
a_scale_thread_bufs(mfma_reg_buf));
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(0, ScalesPerKBlockSize, 0));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
constexpr auto b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
|
||||
auto b_scale_thread_buf_copy =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc_copy.GetElementSpaceSize());
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc_copy,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf_copy);
|
||||
|
||||
b_scale_thread_bufs(mfma_reg_buf)(Number<b_scale_offset>{}) =
|
||||
b_scale_thread_buf_copy[Number<0>{}];
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
|
||||
});
|
||||
});
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
};
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_n0_n1_k0_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// a thread copy
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto k_step =
|
||||
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
|
||||
constexpr auto a_k_step_chunk =
|
||||
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I1][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I1][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / ComputePackedSize, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
|
||||
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
|
||||
|
||||
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
|
||||
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
|
||||
|
||||
// Pack b_scale_thread_buf into b_scale_thread_vec
|
||||
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
|
||||
a_scale_thread_bufs[I0][Number<a_scale_offset + s>{}];
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
|
||||
b_scale_thread_bufs[I0][Number<b_scale_offset + s>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type_a =
|
||||
typename vector_type<ComputeTypeA,
|
||||
xdlops_gemm.K1PerXdlops / APackedSize>::type;
|
||||
using mfma_input_type_b =
|
||||
typename vector_type<ComputeTypeB,
|
||||
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
// MFMA accumulation
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type_a>(),
|
||||
a_scale_thread_vec.template AsType<AScaleDataType>(),
|
||||
b_thread_vec.template AsType<mfma_input_type_b>(),
|
||||
b_scale_thread_vec.template AsType<BScaleDataType>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from a_scale_grid to a_scale_thread
|
||||
static constexpr auto a_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved
|
||||
// here
|
||||
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
|
||||
|
||||
// Is used to copy data from b_scale_grid to b_scale_thread_buf
|
||||
static constexpr auto b_scale_thread_desc_copy =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
|
||||
|
||||
protected:
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
// using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
|
||||
static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,109 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_moe_gufusion_v3.hpp"
|
||||
|
||||
namespace ck {
|
||||
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSche,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
|
||||
// must be used for compute
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
bool GUFusion = false>
|
||||
constexpr auto BlockGemmMXPipeline_Selector()
|
||||
{
|
||||
|
||||
// Hardware MX GEMM pipeline
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
else
|
||||
{
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if constexpr(GUFusion)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_mx_moe_bns_gufusion_v3<
|
||||
BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_mx_moe_nbs_v3<BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,405 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_description/cluster_descriptor.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
/**
|
||||
* Transfer that uses direct load instructions to copy data from global to LDS memory.
|
||||
*
|
||||
* Traditional loads first copy data from global to registers, and then from registers to LDS.
|
||||
* Direct loads do not need an intermediate step, data is copied directly from global to LDS,
|
||||
* without the use of additional registers.
|
||||
*
|
||||
* However, the instruction has limitations:
|
||||
* - each thread must copy exactly a single DWORD - 4 bytes;
|
||||
* - threads within a single wavefront must write consecutive DWORDS into LDS,
|
||||
* (data in global do not need to be contiguous, each thread might have its own offset).
|
||||
*
|
||||
* To make sure that all the transfers finished, the `waitcnt` instruction must be used with
|
||||
* `vmcnt` instead of `lgkmcnt`.
|
||||
*
|
||||
* Limitations of the transfer class:
|
||||
* - `SrcData` must be the same as `DstData` - no possibility to convert the data type in flight;
|
||||
* - `DstVectorDim` must be the last dimension;
|
||||
* - `SrcVectorDim` must be the last dimension if `ScalarPerVector` is greater than 1;
|
||||
* - `ScalarPerVector` times the number of bytes of `DstData` must be equal to a single DWORD = 4B
|
||||
* (for examlpe if `DstData` is fp32, then `ScalarPerVector` must be 1; if `DstData` is fp16,
|
||||
* `ScalarPerVector` must be 2);
|
||||
* - if `ScalarPerVector` is greater than 1, the contiguous dimension in src and dst must be
|
||||
* the same dimension;
|
||||
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
|
||||
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
|
||||
* to guarantee that.
|
||||
*/
|
||||
template <typename ThreadGroup,
|
||||
typename BlockSliceLengths,
|
||||
typename ThreadClusterLengths,
|
||||
typename ThreadClusterArrangeOrder,
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SrcDimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t DstVectorDim,
|
||||
index_t ScalarPerVector,
|
||||
typename IndexType,
|
||||
index_t GatherDim = 1>
|
||||
struct ThreadGroupTensorSliceTransfer_Gather_DirectLoad
|
||||
{
|
||||
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
|
||||
using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
|
||||
|
||||
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
|
||||
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
|
||||
static constexpr auto block_slice_lengths = BlockSliceLengths{};
|
||||
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
|
||||
|
||||
static constexpr auto thread_single_load_size = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, ScalarPerVector>{}, Number<nDim>{});
|
||||
// After a load, each thread moves by `thread_steps` instead of loading the next elements.
|
||||
// It makes the whole wavefront load contiguous memory, what is required for direct loads.
|
||||
static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size;
|
||||
static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps;
|
||||
static constexpr index_t gather_num = thread_slice_lengths.At(Number<GatherDim>{});
|
||||
|
||||
static __device__ constexpr bool AreThreadClusterLengthsValid()
|
||||
{
|
||||
// Make sure that ThreadClusterLengths are set in a way that allows for contiguous writes to
|
||||
// LDS by the threads from a single wavefront.
|
||||
// Examples (assuming 64 threads in a wavefront, 128 in a thread block):
|
||||
// 1. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
|
||||
// data type = fp32 -> ScalarPerVector = 1
|
||||
// INVALID: ThreadClusterLengths = [4, 4, 8] since in the first iteration, threads 0-31
|
||||
// write [0, 0, 0] - [0, 3, 7] and thread 32 writes [1, 0, 0] instead of
|
||||
// [0, 4, 0].
|
||||
// VALID: ThreadClusterLengths = [2, 8, 8] or [1, 16, 8] since in the first iteration,
|
||||
// threads 0-63 write [0, 0, 0] - [0, 7, 7] -> 64 consecutive elements (DWORDs).
|
||||
// 2. BlockSliceLengths = [K0PerBlock, MPerBlock, K1PerBlock] = [4, 128, 8],
|
||||
// data type = fp16 -> ScalarPerVector = 2
|
||||
// NOTE: ThreadClusterLengths must take into account that each thread writes two
|
||||
// elements (single DWORD) along the contiguous dimension.
|
||||
// INVALID: ThreadClusterLengths = [4, 4, 8] since each 8 threads would try to write
|
||||
// 8 * 2 elements of K1PerBlock and there are only 8;
|
||||
// ThreadClusterLengths = [4, 8, 4] since in the first iteration, threads 0-31
|
||||
// write [0, 0, 0] - [0, 7, 7] (7 since each writes 2 elements) and thread 32
|
||||
// writes [1, 0, 0] instead of [0, 8, 0].
|
||||
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
|
||||
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
|
||||
// elements = 64 consecutive DWORDs.
|
||||
#if defined(__gfx950__)
|
||||
int num_contiguous_dwords = 4;
|
||||
#else
|
||||
int num_contiguous_dwords = 1;
|
||||
#endif
|
||||
bool is_contiguous = true;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if(is_contiguous)
|
||||
{
|
||||
num_contiguous_dwords *= thread_cluster_lengths[nDim - i - 1];
|
||||
}
|
||||
if(thread_slice_lengths[nDim - i - 1] > 1)
|
||||
{
|
||||
is_contiguous = false;
|
||||
}
|
||||
});
|
||||
constexpr index_t wavefront_size = get_warp_size();
|
||||
const bool wave_contiguous = num_contiguous_dwords % wavefront_size == 0;
|
||||
|
||||
bool thread_slice_lengths_correct = true;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if(thread_slice_lengths[i] <= 0)
|
||||
{
|
||||
thread_slice_lengths_correct = false;
|
||||
}
|
||||
});
|
||||
|
||||
return wave_contiguous && thread_slice_lengths_correct;
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadGroupTensorSliceTransfer_Gather_DirectLoad(
|
||||
const SrcDesc& src_desc,
|
||||
const Index& src_block_slice_origin,
|
||||
const DstDesc& dst_desc,
|
||||
const Index& dst_block_slice_origin,
|
||||
const StaticallyIndexedArray<IndexType, gather_num>& gather_offsets)
|
||||
: gather_offsets_(gather_offsets)
|
||||
{
|
||||
static_assert(ck::is_same_v<SrcData, DstData>,
|
||||
"Direct load transfer does not support datatypes conversion. Source and "
|
||||
"destination data types must be the same.");
|
||||
|
||||
static_assert(
|
||||
DstVectorDim == nDim - 1,
|
||||
"Direct load transfer requires the destination vector dimension to be the last one.");
|
||||
|
||||
static_assert(ScalarPerVector == 1 || SrcVectorDim == DstVectorDim,
|
||||
"When loading more than one element per thread at once, the contiguous "
|
||||
"dimension must be the same between source and destination.");
|
||||
|
||||
// constexpr auto dword_bytes = 4;
|
||||
// constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
|
||||
// static_assert(bytes_per_thread_load == dword_bytes,
|
||||
// "Direct load transfer requires each thread to load exactly a single "
|
||||
// "DWORD of data.");
|
||||
|
||||
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
|
||||
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
|
||||
nDim == ThreadClusterLengths::Size(),
|
||||
"Inconsistent number of dimensions across lengths and descriptors.");
|
||||
|
||||
static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
|
||||
"The number of threads cannot be less than the number of elements in "
|
||||
"thread cluster lengths.");
|
||||
|
||||
// static_assert(
|
||||
// AreThreadClusterLengthsValid(),
|
||||
// "Thread cluster lengths are incorrect. They must be set in a way that allows a single
|
||||
// " "wavefront to write contiguous DWORDs into LDS memory. ");
|
||||
|
||||
const auto thread_cluster_idx =
|
||||
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
|
||||
|
||||
constexpr auto wave_cluster_lengths = generate_sequence_v2(
|
||||
[&](auto i) {
|
||||
if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
|
||||
{
|
||||
return Number<ThreadGroup::GetNumOfThread() / 64>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return I1;
|
||||
}
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
|
||||
constexpr auto wave_single_load_size =
|
||||
wave_thread_cluster_lengths * thread_single_load_size;
|
||||
constexpr auto wave_cluster_desc_ =
|
||||
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
|
||||
|
||||
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
|
||||
make_multi_index(ThreadGroup::GetThreadId() / 64));
|
||||
|
||||
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
|
||||
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
|
||||
|
||||
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
|
||||
// We don't need threadwise offset for lds since it was calculate by HW
|
||||
// We still need input the wavewise offset.
|
||||
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
auto adjusted_src_origin_idx = [&]() {
|
||||
Index idx;
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
idx(i) = i.value == GatherDim ? 0 : src_slice_origin_idx[Number<i>{}];
|
||||
});
|
||||
return idx;
|
||||
}();
|
||||
|
||||
// CK_PRINT<decltype(adjusted_src_origin_idx)>();
|
||||
// CK_PRINT<decltype(src_slice_origin_idx)>();
|
||||
|
||||
src_coord_ = make_tensor_coordinate(src_desc, adjusted_src_origin_idx);
|
||||
src_slice_origin_ = adjusted_src_origin_idx;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
dst_slice_origin_ = dst_slice_origin_idx;
|
||||
}
|
||||
|
||||
__device__ void ResetDstSliceWindow(const DstDesc& dst_desc)
|
||||
{
|
||||
dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc& dst_desc,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global,
|
||||
"Source data must come from a global memory buffer.");
|
||||
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
|
||||
"Destination data must be stored in an LDS memory buffer.");
|
||||
|
||||
static_assert(
|
||||
ck::is_same_v<remove_cvref_t<typename SrcBuffer::type>, remove_cvref_t<SrcData>>,
|
||||
"SrcBuffer and SrcData data types must be consistent.");
|
||||
static_assert(
|
||||
ck::is_same_v<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>,
|
||||
"DstBuffer and DstData data types must be consistent.");
|
||||
|
||||
constexpr auto dst_access_lengths = thread_slice_lengths;
|
||||
|
||||
const auto dst_forward_steps = generate_steps(dst_desc, 1);
|
||||
const auto dst_backward_steps = generate_steps(dst_desc, -1);
|
||||
const auto src_forward_steps = generate_steps(src_desc, 1);
|
||||
const auto src_backward_steps = generate_steps(src_desc, -1);
|
||||
|
||||
// Loop over the destination block and copy data.
|
||||
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
|
||||
IndexType gather_offset = gather_offsets_[ordered_dst_access_idx[Number<GatherDim>{}]];
|
||||
// src_coord_xor_ = src_coord_;
|
||||
// src_coord_xor_.GetIndex().At(I0) =
|
||||
// src_coord_.GetIndex().At(I0) ^ ((threadIdx.x % 64) / 8);
|
||||
Index new_index = src_coord_.GetIndex();
|
||||
new_index(I0) = src_coord_.GetIndex().At(I0) ^ ((threadIdx.x % 64) / 8);
|
||||
src_coord_xor_ = make_tensor_coordinate(src_desc, new_index);
|
||||
|
||||
const IndexType src_offset = src_coord_xor_.GetOffset() + gather_offset;
|
||||
const IndexType dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
|
||||
|
||||
// Check if src data is not in the logic padding area.
|
||||
// Leave the HW for oob checking
|
||||
// const bool is_src_valid =
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
|
||||
// src_coord_);
|
||||
|
||||
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
|
||||
dst_buf, src_offset, dst_offset, true);
|
||||
|
||||
constexpr auto move_src_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
|
||||
});
|
||||
move_on_dim_(i) &= i.value != GatherDim;
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
constexpr auto move_dst_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim_(i) = ordered_dst_access_idx[i] < dst_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim_(i) &= ordered_dst_access_idx[j] == dst_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim_;
|
||||
}
|
||||
();
|
||||
|
||||
// Decide whether to move forward or backward.
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_idx[I0];
|
||||
|
||||
static_for<1, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * dst_access_lengths[j] + ordered_dst_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
// Move the source coordinate.
|
||||
if constexpr(move_src_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_forward_steps[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(src_desc, src_coord_, src_backward_steps[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// Move the destination coordinate.
|
||||
if constexpr(move_dst_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_forward_steps[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_tensor_coordinate(dst_desc, dst_coord_, dst_backward_steps[i]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Reset the destination slice since the entire buffer has been already filled.
|
||||
ResetDstSliceWindow(dst_desc);
|
||||
}
|
||||
|
||||
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
|
||||
{
|
||||
src_slice_origin_ = src_slice_origin_ + step;
|
||||
src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_);
|
||||
}
|
||||
|
||||
template <typename DescType>
|
||||
__device__ auto generate_steps(const DescType& desc, int sign)
|
||||
{
|
||||
return generate_tuple(
|
||||
[&](auto i) {
|
||||
Index step_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
step_idx(j) = (i.value == j.value) ? sign * thread_steps[i] : 0;
|
||||
});
|
||||
|
||||
return make_tensor_coordinate_step(desc, step_idx);
|
||||
},
|
||||
Number<nDim>{});
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr auto thread_cluster_desc_ =
|
||||
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
|
||||
|
||||
SrcCoord src_coord_;
|
||||
SrcCoord src_coord_xor_;
|
||||
DstCoord dst_coord_;
|
||||
Index src_slice_origin_;
|
||||
Index dst_slice_origin_;
|
||||
StaticallyIndexedArray<IndexType, gather_num> gather_offsets_;
|
||||
// static constexpr auto a_grid_xor_desc = make_naive_tensor_descriptor_packed(
|
||||
// make_tuple(Number<AK0 ^ ((threadIdx / AK0) % AK0)>{}, Number<M>{}, Number<AK1>{}));
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
@@ -194,10 +194,10 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / BPackedSize;
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
@@ -245,49 +245,31 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
|
||||
APackedSize / BlockSize / 4 *
|
||||
(1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) /
|
||||
BPackedSize / BlockSize / 4 * (2) *
|
||||
(IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
|
||||
BlockSize / 4 * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
|
||||
// TODO: Check if this is the right algorithm for minimum_occupancy
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
|
||||
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
|
||||
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
|
||||
? 2
|
||||
: 1
|
||||
: 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
@@ -315,26 +297,15 @@ struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,567 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bpreshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t ScaleBlockSize,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = ADataType,
|
||||
typename ComputeTypeB = BDataType>
|
||||
struct DeviceMoeGemmMXBPreShuffle : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm = GridwiseMoeGemmMX_BPreshuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
ScaleBlockSize,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Check if this is the right algorithm for minimum_occupancy
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
|
||||
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
|
||||
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
|
||||
? 2
|
||||
: 1
|
||||
: 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v3 support now");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// only impl kbatch 1 now
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_sorted_token_ids,
|
||||
const void* p_sorted_expert_ids,
|
||||
const void* p_max_token_id,
|
||||
const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t NumTokens,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
|
||||
static_cast<const index_t*>(p_sorted_expert_ids),
|
||||
static_cast<const index_t*>(p_max_token_id),
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
NumTokens,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M, // randoms set, no use
|
||||
0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmmMx"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user