Merge branch 'test_copy_fix' of https://github.com/ROCm/composable_kernel into fa_decode_pipeline

This commit is contained in:
aska-0096
2025-07-17 07:24:32 +00:00
430 changed files with 41159 additions and 6951 deletions

View File

@@ -403,10 +403,10 @@ bool bnorm_bwd_nhwc_test(bool do_verification,
return (pass);
};
static const double epsilon = std::numeric_limits<float>::epsilon();
int main(int argc, char* argv[])
{
static const double epsilon = std::numeric_limits<float>::epsilon();
bool pass = true;
if(argc > 1)

View File

@@ -314,11 +314,10 @@ bool bnorm_infer_nhwc_test(bool do_verification,
return (pass);
};
static const double epsilon = std::numeric_limits<float>::epsilon();
int main(int argc, char* argv[])
{
bool pass = true;
static const double epsilon = std::numeric_limits<float>::epsilon();
bool pass = true;
if(argc > 1)
{

View File

@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
return (pass);
};
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
int main(int argc, char* argv[])
{
bool pass = true;
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
bool pass = true;
if(argc > 1)
{

View File

@@ -453,12 +453,11 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
return (pass);
};
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
int main(int argc, char* argv[])
{
bool pass = true;
const double epsilon = std::numeric_limits<float>::epsilon();
static const double averageFactor = 0.1;
bool pass = true;
if(argc > 1)
{

View File

@@ -22,7 +22,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_moe_gemm1_xdl_pk_i4 moe_gemm1_xdl_pk_i4.cpp)
add_example_executable(example_moe_gemm2_xdl_pk_i4 moe_gemm2_xdl_pk_i4.cpp)
if(CK_hip_VERSION VERSION_LESS_EQUAL 6.3.42132)
if(hip_VERSION_FLAT LESS_EQUAL 600342132)
set(EXAMPLE_COMPILE_OPTIONS)
check_cxx_compiler_flag("-mllvm --amdgpu-enable-max-ilp-scheduling-strategy=1" HAS_MAX_ILP_SCHEDULING_STRATEGY)
if(HAS_MAX_ILP_SCHEDULING_STRATEGY)
@@ -31,8 +31,7 @@ foreach(gpu IN LISTS GPU_TARGETS)
example_compile_options(example_moe_gemm1_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_pk_i4 PRIVATE ${EXAMPLE_COMPILE_OPTIONS})
endif()
set(GEMM_OPTIONS)
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
example_compile_options(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_fp8 PRIVATE ${GEMM_OPTIONS})
@@ -40,16 +39,25 @@ foreach(gpu IN LISTS GPU_TARGETS)
endif()
endforeach()
set(GEMM_OPTIONS)
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
set(BLOCKSCALE_GEMM_OPTIONS)
set(GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
set(BLOCKSCALE_GEMM_OPTIONS )
check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP)
check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION)
if(HAS_MISCHED_BOTTOMUP)
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1")
elseif(HAS_MISCHED_PRERA_DIRECTION)
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup")
if(hip_VERSION_FLAT LESS 600443483 OR hip_VERSION_FLAT GREATER_EQUAL 700000000)
if(HAS_MISCHED_BOTTOMUP)
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-bottomup=1")
elseif(HAS_MISCHED_PRERA_DIRECTION)
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --schedmodel=0 -mllvm --misched-prera-direction=bottomup")
endif()
else()
if(HAS_MISCHED_BOTTOMUP)
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-bottomup=1")
elseif(HAS_MISCHED_PRERA_DIRECTION)
list(APPEND BLOCKSCALE_GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32 -mllvm --misched-prera-direction=bottomup")
endif()
endif()
check_cxx_compiler_flag("-mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental " HAS_MAX_OCCUPANCY_EXPERIMENTAL)
if(HAS_MAX_OCCUPANCY_EXPERIMENTAL)
list(APPEND BLOCKSCALE_GEMM_OPTIONS -mllvm --amdgpu-sched-strategy=gcn-iterative-max-occupancy-experimental)

View File

@@ -139,6 +139,7 @@ static constexpr ck::index_t EVec = 2;
static constexpr ck::index_t D0Vec = 1;
static constexpr ck::index_t D1Vec = 1;
static constexpr ck::index_t D2Vec = 1;
static constexpr bool PerTokenQuant = true;
static constexpr bool MulRoutedWeight = true;
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm
// clang-format off
@@ -169,7 +170,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
// MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S<EVec, D0Vec, D1Vec, D2Vec>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, false, int32_t, A0DataType>;
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, PerTokenQuant, int32_t, A0DataType>;
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>;
@@ -197,7 +198,7 @@ int main(int argc, char* argv[])
{
// use default case
}
else if(argc == 3)
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
@@ -238,7 +239,8 @@ int main(int argc, char* argv[])
ck::index_t StrideB = K;
ck::index_t StrideE = N;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
constexpr auto StrideDs = PerTokenQuant ? std::array<ck::index_t, NumDTensor>{1, 1, 0}
: std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
@@ -279,8 +281,10 @@ int main(int argc, char* argv[])
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<D0DataType> d0_t_n(HostTensorDescriptor({tokens, N}, {StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(HostTensorDescriptor({experts, N}, {1, StrideDs[1]}));
Tensor<D0DataType> d0_t_n(
HostTensorDescriptor({tokens, topk, N}, {StrideDs[0] * topk, StrideDs[0], 0}));
Tensor<D1DataType> d1_e_n(
HostTensorDescriptor({experts, N}, {PerTokenQuant ? StrideDs[1] * N : 1, StrideDs[1]}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));

View File

@@ -10,6 +10,12 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
# add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
add_example_executable(example_gemm_mx_fp6 gemm_mx_fp6.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp6)
add_example_executable(example_gemm_mx_bf6 gemm_mx_bf6.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_bf6)
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
@@ -22,17 +28,41 @@ add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns)
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns)
add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4)
add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4)
add_example_executable(example_moe_gemm1_xdl_mx_fp4_bpreshuffle moe_gemm1_xdl_mx_fp4_bpreshuffle.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bpreshuffle)
add_example_executable(example_moe_gemm2_xdl_mx_fp4_bpreshuffle moe_gemm2_xdl_mx_fp4_bpreshuffle.cpp)
add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bpreshuffle)
set(FP4_MXGEMM_OPTIONS)
list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1")
example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
# mx moe B no-shuffling + scale shuffling
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
# mx moe B no-shuffling + scale shuffling (async loads)
example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
# mx moe B shuffling + scale shuffling (async loads)
example_compile_options(example_moe_gemm1_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_moe_gemm2_xdl_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
set(FP8_MXGEMM_OPTIONS)
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})
set(FP6_MXGEMM_OPTIONS)
list(APPEND FP6_MXGEMM_OPTIONS -mavx512f)
example_compile_options(example_gemm_mx_fp6 PRIVATE ${FP6_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_bf6 PRIVATE ${FP6_MXGEMM_OPTIONS})

View File

@@ -8,14 +8,16 @@ Custom verification parameters:
# arg2: initialization (0=constant values, 1=integer values, 2=decimal values)
# arg3: time kernel (0=no, 1=yes)
# arg4: verbosity (0=no info, 1=verbose info)
# arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC
# arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC
# arg11: KBatch
# arg12: warmup runs pre-timing
# arg13: repeat run count for timing
./bin/example_gemm_mx_fp8 1 1 0 1
```
Custom tensor shapes:
```bash
./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1
./bin/example_gemm_mx_fp8 1 2 1 0 256 256 512 -1 -1 -1 1 10 10
```
Default invocation:

View File

@@ -0,0 +1,101 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::bf6x16_pk_t;
using BDataType = ck::bf6x16_pk_t;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 16; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 bf6 = 16 bf6x16_pk_t
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
1, // AK1
1, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
1, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
1, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}

View File

@@ -100,8 +100,11 @@ bool parse_cmd_args(int argc,
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
<< "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl
<< "arg11: KBatch" << std::endl;
<< "arg5 to 10: M(256x), N(256x), K(512x), StrideA, StrideB, StrideC" << std::endl
<< "arg11: KBatch" << std::endl
<< "arg12: warmup runs pre-timing" << std::endl
<< "arg13: repeat run count for timing" << std::endl;
return false;
}
@@ -245,6 +248,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
if(K % ck::packed_size_v<ADataType> != 0 || K % ck::packed_size_v<BDataType> != 0)
{
throw std::runtime_error("wrong! K must be multiple of packed size.");
};
// Hardcode scale layouts as per pipeline assumptions
// TODO: Allow user to specify scale layouts
using AScaleLayout = Row;
@@ -292,12 +300,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto a_data_element = [](float x) {
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
return ck::type_convert<ADataType>(ck::float2_t(x));
else if constexpr(ck::packed_size_v<ADataType> == 32)
return ck::type_convert<ADataType>(ck::float32_t(x));
else if constexpr(ck::packed_size_v<ADataType> == 16)
return ck::type_convert<ADataType>(ck::float16_t(x));
else
return ck::type_convert<ADataType>(x);
};
auto b_data_element = [](float x) {
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
return ck::type_convert<BDataType>(ck::float2_t(x));
else if constexpr(ck::packed_size_v<BDataType> == 32)
return ck::type_convert<BDataType>(ck::float32_t(x));
else if constexpr(ck::packed_size_v<BDataType> == 16)
return ck::type_convert<BDataType>(ck::float16_t(x));
else
return ck::type_convert<BDataType>(x);
};
@@ -307,30 +323,35 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
switch(config.init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
ck::utils::FillConstant<ADataType>{a_data_element(0.5f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{b_data_element(2.0f)}(*b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
if(config.verbosity > 0)
{
std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A = {0.5}" << std::endl;
std::cout << "Init A scale = {2.0}" << std::endl;
std::cout << "Init B = {0.5}" << std::endl;
std::cout << "Init B scale = {1.0}" << std::endl;
std::cout << "Init B = {2.0}" << std::endl;
std::cout << "Init B scale = {0.5}" << std::endl;
std::cout << "Expect C = {K}" << std::endl;
}
break;
case 1:
a_m_k.GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
b_k_n->GenerateTensorDistr(int_distr{-5, 6}); // Z[-5,5]
a_m_k.GenerateTensorDistr(
int_distr{-5, 5}, ck::identity{}, std::minstd_rand(time(nullptr))); // Z[-5,5]
b_k_n->GenerateTensorDistr(int_distr{-5, 5}); // Z[-5,5]
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
a_m_k_scale.GenerateTensorDistr(int_distr{120, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorDistr(int_distr{125, 129}); // scales: {0.25, 0.5, 1, 2}
a_m_k_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorDistr(int_distr{125, 128}); // scales: {0.25, 0.5, 1, 2}
break;
case 2:
a_m_k.GenerateTensorDistr(float_distr{-2.0, 2.0});
a_m_k.GenerateTensorDistr(
float_distr{-2.0, 2.0}, ck::identity{}, std::minstd_rand(time(nullptr))); // R[-2,2]
a_m_k_scale.GenerateTensorDistr(float_distr{powf(2.0f, -125.0f), 1.0f});
b_k_n->GenerateTensorDistr(float_distr{-2.0, 2.0});

View File

@@ -0,0 +1,99 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "gemm_mx_common.hpp"
using ADataType = ck::f6x16_pk_t;
using BDataType = ck::f6x16_pk_t;
using XDataType = ck::e8m0_bexp_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / ck::packed_size_v<ADataType>; // K dimension size per block
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XDataType, // AScaleDataType
BDataType, // BDataType
XDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Number of threads per block
128, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
1, // AK1 number of elements to read at a time when transferring from global memory to LDS
1, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
1, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
1, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}

View File

@@ -0,0 +1,548 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr ck::index_t NPerBlock = 64;
static constexpr ck::index_t BlockSize = 256;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, BlockSize,
MPerBlock, NPerBlock, KPerBlock,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3,
ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
max_token_id.mData[0] = valid_size;
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_k_n_host_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_k_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
e_t_k_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{1});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{0.5f});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{1.5f});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{1.0f});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k(token_id, k);
}
}
}
// A/B scale shuffle
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
b_scale_preshuffled.mData.data(),
N * 2 * experts,
K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop =
// FMA * tokens * N * (Gate+Up) * topk * K +
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
std::size_t(2) * tokens * N * 2 * topk * K +
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
sizeof(EDataType) * tokens * topk * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
XDataType,
B0DataType,
XDataType,
float, // CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough,
ActOP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k,
a1_t_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n,
c_t_k_n,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < valid_size; ++m)
{
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
if(t >= tokens)
{
continue;
}
for(int n = 0; n < N; ++n)
{
e_t_k_n_host_result(t, topk_id, n) =
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
}
}
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
auto status =
ck::utils::check_err(
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
if(status == 0)
{
printf("Validation Pass.\n");
}
return status;
}
return 0;
}

View File

@@ -0,0 +1,574 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm1.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using I64 = int64_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// B preshuffle
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
I64 tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K_pk; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t Nswizzle = false;
static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 64, KPerBlock,
16, 16,
16, 16,
4, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 32, 1, 8>, S<8, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({sorted_tile_num + 1}));
max_token_id.mData[0] = valid_size;
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts);
}
int token_per_tile = (tokens * topk + valid_tile_num - 1) / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
Tensor<A0DataType> a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1}));
Tensor<XDataType> a1_t_k(HostTensorDescriptor(
{tokens, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{(N * 2 * Scale_Stride_BN), 1, Scale_Stride_BN}));
// B preshuffle
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N * 2},
{N * 2 * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_k_n_host_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
Tensor<EDataType> e_t_k_n_device_result(
HostTensorDescriptor({tokens, topk, N}, {topk * N, N, 1}));
e_t_k_n_device_result.SetZero();
std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl;
std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_k_n: " << e_t_k_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 3:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 4:
a0_t_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 5:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{0.1f});
break;
case 6:
a0_t_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_k_n_device_result.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k(token_id, k);
}
}
}
// A/B scale shuffle
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(b1_e_n_k.mData.data(),
b_scale_preshuffled.mData.data(),
N * 2 * experts,
K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_k_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
preShuffleBuffer(b0_e_n_k.mData.data(),
b0_preshuffled.mData.data(),
N * 2 * experts,
K,
device_op.GetPreShuffleParameters());
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
std::size_t flop =
// FMA * tokens * N * (Gate+Up) * topk * K +
// FMA * tokens * N * (Gate+Up) * topk * (K/BlockScale)
std::size_t(2) * tokens * N * 2 * topk * K +
std::size_t(2) * tokens * N * 2 * topk * K / ScaleBlockSize;
std::size_t num_btype = sizeof(A0DataType) / 2 * tokens * topk * K +
sizeof(B0DataType) / 2 * K * N * 2 * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * 2 * experts +
sizeof(EDataType) * tokens * topk * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> c_t_k_n({tokens, topk, N}, {topk * N, N, 1});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm1<A0DataType,
XDataType,
B0DataType,
XDataType,
float, // CShuffleDataType,
D2DataType,
AccDataType,
PassThrough,
PassThrough,
PassThrough,
ActOP,
MulRoutedWeight>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k,
a1_t_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n,
c_t_k_n,
PassThrough{},
PassThrough{},
PassThrough{});
ref_invoker.Run(ref_argument);
for(int m = 0; m < valid_size; ++m)
{
const int fuse_t = sorted_token_ids.mData[m];
const int t = fuse_t & 0xffffff;
const int topk_id = (fuse_t & 0xff000000) >> 24;
if(t >= tokens)
{
continue;
}
for(int n = 0; n < N; ++n)
{
e_t_k_n_host_result(t, topk_id, n) =
ck::type_convert<EDataType>(c_t_k_n(t, topk_id, n));
}
}
e_device_buf.FromDevice(e_t_k_n_device_result.mData.data());
auto status =
ck::utils::check_err(
e_t_k_n_device_result, e_t_k_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1)
? 0
: 1;
if(status == 0)
{
printf("Validation Pass.\n");
}
return status;
}
return 0;
}

View File

@@ -0,0 +1,542 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMX<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = sorted_tile_num;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
for(int i = 0; i < sorted_tile_num; i++)
{
if(i < valid_tile_num)
{
eids[i] = (i * experts) / valid_tile_num;
}
else
{
eids[i] = 3;
}
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<XDataType> a1_t_k_k(
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 8:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
// d2_e_n.savetxt("weight.txt", "int");
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
}
}
}
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
b0_device_buf.ToDevice(b0_e_n_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// FMA * tokens * N * topk * K +
// FMA * tokens * N * topk * (K/BlockScale)
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
XDataType,
B0DataType,
XDataType,
D2DataType,
float, // using float for Cshuffle type
// in reference
AccDataType,
PassThrough,
PassThrough,
CDEElementOp,
MulRoutedWeight,
float,
float>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k_k,
a1_t_k_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n, // topk weights
c_t_n,
PassThrough{},
PassThrough{},
cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
}
return 0;
}

View File

@@ -158,7 +158,7 @@ using DeviceOpInstance = ck::tensor_operation::device::Devic
4, 4,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0,
2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>,
2, 4, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on

View File

@@ -0,0 +1,584 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_moe_mx_gemm_bpreshuffle.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_moe_mx_gemm2.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t; // 4 packed e8m0_bexp_t
using I64 = int64_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using A0DataType = F4;
using A1DataType = XPackedDataType;
using B0DataType = F4;
using B1DataType = XPackedDataType;
using EDataType = F16;
using AccDataType = F32;
using CShuffleDataType = F16;
using D0DataType = F32;
using D1DataType = F32;
using D2DataType = F32;
using DsDataType = ck::Tuple<D0DataType, D1DataType, D2DataType>;
using A0Layout = Row;
using B0Layout = Col;
using ELayout = Row;
using D0Layout = Row;
using D1Layout = Col;
using D2Layout = ELayout;
using DsLayout = ck::Tuple<D0Layout, D1Layout, D2Layout>;
// d0: ascale, d1: bscale, d2:expert weight
struct MulABScaleExpertWeight
{
template <typename E, typename C, typename D0, typename D1, typename D2>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1, const D2& d2) const;
// for real kernel use
template <>
__host__ __device__ constexpr void operator()<EDataType, F16, float, float, float>(
EDataType& e, const F16& c, const float& d0, const float& d1, const float& d2) const
{
(void)d0;
(void)d1;
(void)d2;
e = ck::type_convert<EDataType>(c);
}
// for reference cpu
template <>
__host__ __device__ constexpr void operator()<float, float, float, float, float>(
float& e, const float& c, const float& d0, const float& d1, const float& d2) const
{
// for reference cpu
e = ck::type_convert<EDataType>(c * d0 * d1 * d2);
}
};
using CDEElementOp = MulABScaleExpertWeight;
// B preshuffle
void preShuffleBuffer(const F4* src, F4* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
I64 tempk;
for(I64 n = 0; n < N; ++n)
{
for(I64 k = 0; k < K_pk; ++k)
{
I64 n0 = n / NLane;
I64 n1 = n % NLane;
I64 k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
I64 k1 = tempk / KPack;
I64 k2 = tempk % KPack;
I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
// A, B Scale preshuffle
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The 4 16x128 building blocks will be packed into 1 32x256 for F4
// The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4
// unfold the MN32xK(256/32) scale buffer
// 4 16 2 2
// To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack
// Then, MNRepeat->KRepeat
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = MulABScaleExpertWeight;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
static constexpr ck::index_t MPerBlock = 128;
static constexpr bool MulRoutedWeight = true;
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmMXBPreShuffle<
A0Layout, B0Layout, DsLayout, ELayout,
A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType,
AElementOp, BElementOp, CDEElementOp, GemmSpec,
ScaleBlockSize, 256,
MPerBlock, 128, KPerBlock,
16, 16,
16, 16,
8, 2,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1,
2, 2, S<1, 4, 1, 64>, S<2, 1, 1, 1>,
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, ck::index_t, A0DataType>;
// clang-format on
int main(int argc, char* argv[])
{
bool do_verification = true;
int init_method = 1;
bool time_kernel = true;
// per expert:
// GEMM shape
constexpr ck::index_t sorted_tile_num = 13;
constexpr ck::index_t valid_tile_num = 13;
ck::index_t sorted_size = sorted_tile_num * MPerBlock;
ck::index_t valid_size = valid_tile_num * MPerBlock;
ck::index_t N = 6144;
ck::index_t K = 4096;
ck::index_t experts = 8;
ck::index_t tokens = 832;
ck::index_t topk = 2;
if(argc == 1)
{
// use default case
}
else if(argc == 4)
{
// use default case
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
}
else if(argc == 7)
{
do_verification = std::stoi(argv[1]);
init_method = std::stoi(argv[2]);
time_kernel = std::stoi(argv[3]);
N = std::stoi(argv[4]);
K = std::stoi(argv[5]);
tokens = std::stoi(argv[6]);
}
else
{
printf("arg1: verification (0=no, 1=yes)\n");
printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
printf("arg3: time kernel (0=no, 1=yes)\n");
printf("arg4 to 6: N, K, tokens\n");
exit(0);
}
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
ck::index_t StrideA = K;
ck::index_t StrideB = K;
ck::index_t StrideE = N;
ck::index_t Scale_Stride_AM = (K + ScaleBlockSize - 1) / ScaleBlockSize;
ck::index_t Scale_Stride_BN = (K + ScaleBlockSize - 1) / ScaleBlockSize;
constexpr ck::index_t NumDTensor = DsDataType::Size();
constexpr auto StrideDs = std::array<ck::index_t, NumDTensor>{0, 0, 0};
ck::index_t KBatch = 1;
Tensor<ck::index_t> expert_ids(HostTensorDescriptor({sorted_tile_num}, {1}));
Tensor<ck::index_t> sorted_token_ids(HostTensorDescriptor({sorted_size}, {1}));
Tensor<ck::index_t> max_token_id(HostTensorDescriptor({1}));
max_token_id.mData[0] = valid_size;
// int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3};
int eids[sorted_tile_num]{};
for(int i = 0; i < sorted_tile_num; i++)
{
if(i < valid_tile_num)
{
eids[i] = (i * experts) / valid_tile_num;
}
else
{
eids[i] = 3;
}
}
for(int i = 0; i < sorted_tile_num; i++)
{
expert_ids.mData[i] = eids[i];
}
if(tokens * topk > valid_size)
{
printf("err config, tokens * topk > valid_size\n");
exit(-1);
}
int token_per_tile = tokens * topk / valid_tile_num;
int tokenid = 0;
for(int i = 0; i < sorted_size; i++)
{
int tile_off = i % MPerBlock;
if(tile_off < token_per_tile)
{
sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24);
tokenid++;
}
else
{
sorted_token_ids.mData[i] = tokens;
}
}
expert_ids.savetxt("expert_ids.txt", "int");
sorted_token_ids.savetxt("sorted_token_ids.txt", "int");
Tensor<A0DataType> a0_t_k_k(HostTensorDescriptor({tokens, topk, K}, {topk * K, K, 1}));
Tensor<XDataType> a1_t_k_k(
HostTensorDescriptor({tokens, topk, (K + ScaleBlockSize - 1) / ScaleBlockSize},
{(topk * Scale_Stride_AM), Scale_Stride_AM, 1}));
Tensor<B0DataType> b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
Tensor<XDataType> b1_e_n_k(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{(N * Scale_Stride_BN), 1, Scale_Stride_BN}));
// B preshuffle
Tensor<B0DataType> b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K}));
// A, B Scale preshuffle
Tensor<XDataType> a_scale_sorted(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> a_scale_preshuffled(HostTensorDescriptor(
{sorted_size, (K + ScaleBlockSize - 1) / ScaleBlockSize}, {Scale_Stride_AM, 1}));
Tensor<XDataType> b_scale_preshuffled(
HostTensorDescriptor({experts, (K + ScaleBlockSize - 1) / ScaleBlockSize, N},
{N * Scale_Stride_BN, 1, Scale_Stride_BN}));
Tensor<D2DataType> d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0}));
Tensor<EDataType> e_t_n_host_result(HostTensorDescriptor({tokens, N}, {N, 1}));
Tensor<EDataType> e_t_n_device_result(HostTensorDescriptor({tokens, N}, {N, 1}));
e_t_n_device_result.SetZero();
std::cout << "a0_t_k_k: " << a0_t_k_k.mDesc << std::endl;
std::cout << "a1_t_k_k: " << a1_t_k_k.mDesc << std::endl;
std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl;
std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl;
std::cout << "d2_e_n: " << d2_e_n.mDesc << std::endl;
std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl;
switch(init_method)
{
case 0: break;
case 1:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-1, 1});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-1, 1});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0, 1.0});
break;
case 2:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 3:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 4:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 5.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 5:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 6:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_1<B0DataType>{});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 7:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_1<A0DataType>{});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
case 8:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-2, 2});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_1<XDataType>{});
d2_e_n.GenerateTensorValue(GeneratorTensor_1<D2DataType>{});
break;
default:
a0_t_k_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
b0_e_n_k.GenerateTensorValue(GeneratorTensor_3<B0DataType>{-0.5, 0.5});
a1_t_k_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
b1_e_n_k.GenerateTensorValue(GeneratorTensor_3<XDataType>{0.0, 1.0});
d2_e_n.GenerateTensorValue(GeneratorTensor_3<D2DataType>{0.0, 1.0});
}
DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.GetElementSpaceSize());
DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.GetElementSpaceSize());
DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.GetElementSpaceSize());
DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k_k.GetElementSpaceSize());
DeviceMem a1_device_buf(sizeof(XDataType) * a_scale_sorted.GetElementSpaceSize());
DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.GetElementSpaceSize());
DeviceMem b1_device_buf(sizeof(XDataType) * b1_e_n_k.GetElementSpaceSize());
DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.GetElementSpaceSize());
DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.GetElementSpaceSize());
// A scale sorted
for(int i = 0; i < sorted_size; i++)
{
int token_id = sorted_token_ids.mData[i] & 0x00FFFFFF;
int topk_id = (sorted_token_ids.mData[i] >> 24) & 0x000000FF;
for(int k = 0; k < (K + ScaleBlockSize - 1) / ScaleBlockSize; k++)
{
if(token_id == tokens)
{
a_scale_sorted(i, k) = ck::type_convert<XDataType>(0);
}
else
{
a_scale_sorted(i, k) = a1_t_k_k(token_id, topk_id, k);
}
}
}
// A, B Scale preshuffle
preShuffleScaleBuffer<ck::is_same_v<A0Layout, Row>>(a_scale_sorted.mData.data(),
a_scale_preshuffled.mData.data(),
sorted_size,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<B0Layout, Col>>(
b1_e_n_k.mData.data(), b_scale_preshuffled.mData.data(), N * experts, K / ScaleBlockSize);
sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data());
expert_ids_dev.ToDevice(expert_ids.mData.data());
max_token_id_dev.ToDevice(max_token_id.mData.data());
a0_device_buf.ToDevice(a0_t_k_k.mData.data());
a1_device_buf.ToDevice(a_scale_preshuffled.mData.data());
b1_device_buf.ToDevice(b_scale_preshuffled.mData.data());
d2_device_buf.ToDevice(d2_e_n.mData.data());
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
// do GEMM
auto device_op = DeviceOpInstance{};
preShuffleBuffer(b0_e_n_k.mData.data(),
b0_preshuffled.mData.data(),
N * experts,
K,
device_op.GetPreShuffleParameters());
b0_device_buf.ToDevice(b0_preshuffled.mData.data());
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(
sorted_token_ids_dev.GetDeviceBuffer(),
expert_ids_dev.GetDeviceBuffer(),
max_token_id_dev.GetDeviceBuffer(),
a0_device_buf.GetDeviceBuffer(),
a1_device_buf.GetDeviceBuffer(),
b0_device_buf.GetDeviceBuffer(),
b1_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{nullptr, nullptr, d2_device_buf.GetDeviceBuffer()},
e_device_buf.GetDeviceBuffer(),
tokens,
topk,
sorted_size,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideDs,
StrideE,
KBatch,
a_element_op,
b_element_op,
cde_element_op);
if(!device_op.IsSupportedArgument(argument))
{
throw std::runtime_error(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem");
}
if(!(ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950"))
{
std::cout << "This kernel support gfx942 and gfx950 only" << std::endl;
}
if(time_kernel)
{
// not result correct here because output buf not setzero
float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
// FMA * tokens * N * topk * K +
// FMA * tokens * N * topk * (K/BlockScale)
std::size_t flop = std::size_t(2) * tokens * topk * N * K +
std::size_t(2) * tokens * topk * N * K / ScaleBlockSize;
std::size_t num_btype =
sizeof(A0DataType) / 2 * tokens * K * topk + sizeof(B0DataType) / 2 * K * N * experts +
sizeof(XDataType) * tokens * topk * K / ScaleBlockSize +
sizeof(XDataType) * K / ScaleBlockSize * N * experts + sizeof(EDataType) * tokens * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
if(do_verification)
{
// gemm2 use atomic, so need to reinit outputs
e_device_buf.ToDevice(e_t_n_device_result.mData.data());
invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1});
Tensor<float> c_t_n({tokens, N});
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMoeMXGemm2<A0DataType,
XDataType,
B0DataType,
XDataType,
D2DataType,
float, // using float for Cshuffle type
// in reference
AccDataType,
PassThrough,
PassThrough,
CDEElementOp,
MulRoutedWeight,
float,
float>;
auto ref_moe_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_moe_gemm.MakeInvoker();
auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids,
expert_ids,
max_token_id,
MPerBlock,
a0_t_k_k,
a1_t_k_k,
b0_e_n_k,
b1_e_n_k,
d2_e_n, // topk weights
c_t_n,
PassThrough{},
PassThrough{},
cde_element_op);
ref_invoker.Run(ref_argument);
for(int t = 0; t < tokens; ++t)
{
for(int n = 0; n < N; ++n)
{
e_t_n_host_result(t, n) = ck::type_convert<EDataType>(c_t_n(t, n));
}
}
e_device_buf.FromDevice(e_t_n_device_result.mData.data());
return ck::utils::check_err(
e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2)
? 0
: 1;
}
return 0;
}

View File

@@ -128,6 +128,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
add_executable(${EXAMPLE_NAME} ${FILE_NAME})
target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
target_link_libraries(${EXAMPLE_NAME} PRIVATE getopt::getopt)
add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS} )
add_dependencies(examples ${EXAMPLE_NAME})

View File

@@ -1,7 +1,9 @@
# validate user-specified fmha_fwd API list
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;fwd_decode")
set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill;fwd_decode")
set(FMHA_FWD_ENABLE_APIS "fwd_decode" CACHE STRING
"semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
# set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
# "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
endif()
@@ -17,14 +19,30 @@ endforeach()
# list(APPEND FMHA_FWD_ENABLE_APIS "fwd")
# endif()
# Filtering kernel
# set(KERNEL fmha_fwd_decode_d64_bf16_batch_b16x64x32x64x32x64_r1x4x1_r1x4x1_w16x16x32_w16x16x32_decode_qr_vr_psskddv_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant)
file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
${CMAKE_CURRENT_LIST_DIR}/generate.py
${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
)
# re-run execute_process `generate.py --list_blobs` if any of the codegen scripts change
set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
set(FMHA_FWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS}
# --filter fmha_fwd...
)
set(FMHA_BWD_CODE_GEN_COMMON_ARGS
${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd
--receipt 3
# --filter fmha_bwd_dot...@fmha_bwd_convert...@fmha_bwd...
)
# generate a list of kernels, but not actually emit files at config sta
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt #--filter ${KERNEL}
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
@@ -32,8 +50,8 @@ if(ret AND NOT ret EQUAL 0)
endif()
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt --receipt 3
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
--list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
RESULT_VARIABLE ret
)
if(ret AND NOT ret EQUAL 0)
@@ -47,14 +65,16 @@ file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
add_custom_command(
OUTPUT ${FMHA_FWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api ${FMHA_FWD_APIS} --output_dir ${CMAKE_CURRENT_BINARY_DIR} #--filter ${KERNEL}
COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
)
add_custom_command(
OUTPUT ${FMHA_BWD_GEN_BLOBS}
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
--api bwd --output_dir ${CMAKE_CURRENT_BINARY_DIR} --receipt 3
COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
--output_dir ${CMAKE_CURRENT_BINARY_DIR}
DEPENDS ${CODE_GEN_SCRIPTS}
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
@@ -76,7 +96,7 @@ target_sources(${EXAMPLE_FMHA_BWD} PRIVATE ${FMHA_BWD_GEN_BLOBS})
# NOTE: this is dangerous since will change the whole kernel to flush denormals
# WIP with compiler team for an exp2 intrinsic..., then remove this
if(NOT DEFINED FMHA_FWD_FAST_EXP2)
set(FMHA_FWD_FAST_EXP2 true)
set(FMHA_FWD_FAST_EXP2 true)
endif()
set(EXAMPLE_FMHA_FWD_COMPILE_OPTIONS)
@@ -85,9 +105,9 @@ set(EXAMPLE_FMHA_BWD_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
# ... because they are auto-generated
if(FMHA_FWD_FAST_EXP2)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
endif()
# list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-undefined-func-template -fgpu-flush-denormals-to-zero)
@@ -113,6 +133,13 @@ else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_DECODE_API=0)
endif()
# conditionally enable call to the pagedkv_prefill API in fmha_fwd example
if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
else()
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
endif()
# conditionally specify the use of OCP_FP8
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)

View File

@@ -71,6 +71,7 @@ args:
-drop_seed seed for random number generator (default:1)
-drop_offset offset for random number generator (default:0)
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
-warmup number of iterations before benchmark the kernel (default:5)
-repeat number of iterations to benchmark the kernel (default:20)
```

View File

@@ -123,6 +123,7 @@ PIPELINE_ENUM_MAP = {
# "qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
# "qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
# "qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
# "qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"decode_qr" : "ck_tile::BlockFmhaPipelineEnum::DECODE_QRKSVS",
}

View File

@@ -150,14 +150,14 @@ unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seq
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
float r = -1;
const float min_cu_util_rate = 0.8; // minimum CU utilization rate
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
auto get_num_blocks = [&](unsigned kM0) {{
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
@@ -490,7 +490,7 @@ class KernelComponentFactory:
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
'128' : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
else:
return None
@@ -516,13 +516,11 @@ class KernelComponentFactory:
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == 'fp16' or dtype == 'bf16':
return {
'128' : [FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')),
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),]
}
else:
return None
if 128 in result.keys():
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
return result
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
@@ -536,9 +534,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tiles = d[hdim_str]
hdim = int(hdim_str)
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':

View File

@@ -169,7 +169,7 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
if(s.log_level_ > 0)
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
return ck_tile::launch_kernel(s,
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
@@ -527,6 +527,7 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= bias in ['no', 'bias']
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
cond &= dpad == dvpad
cond &= mode == 'batch'
cond &= deterministic == "f"
if not cond:
continue

View File

@@ -3,9 +3,10 @@
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
from dataclasses import dataclass, field
import fnmatch
import itertools
import os
from pathlib import Path
from typing import List, Optional, Tuple
@@ -114,8 +115,52 @@ float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
FMHA_FWD_API="""
#include <cstdio>
#include <hip/hip_runtime.h>
namespace {{
bool get_num_cus(unsigned& num_cus) {{
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device");
return false;
}}
hipDeviceProp_t props{{}};
status = hipGetDeviceProperties(&props, device);
if(status != hipSuccess) {{
fprintf(stderr, "failed to get device properties");
return false;
}}
num_cus = props.multiProcessorCount;
return true;
}}
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
return batch * nheads * num_m_blocks * num_n_blocks;
}}
}} // namespace
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
float r = -1;
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
unsigned num_cus;
if (!get_num_cus(num_cus)) {{
return r;
}}
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
{F_dispatch}
return r;
}}
@@ -131,37 +176,51 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_<trait_>(s, a);
}}
"""
@dataclass
class CppConstraint:
bool_expr: str = None
def __str__(self):
if self.bool_expr is None:
return 'true'
else:
return f'{self.bool_expr}'
def __and__(self, other):
return CppConstraint(f'({str(self)}) && ({str(other)})')
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
skip : str
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
dropout : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
skip : str
constraint : CppConstraint
@property
def name(self) -> str:
@@ -218,18 +277,19 @@ class FmhaFwdApiTrait:
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_dropout : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
@@ -303,6 +363,7 @@ class FmhaFwdApiPool:
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
@@ -317,25 +378,27 @@ class FmhaFwdApiPool:
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
@@ -429,35 +492,38 @@ class FmhaFwdKernel:
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip)
skip=self.F_pipeline.F_skip,
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
(32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
(64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
(192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1),
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
(64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
(128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
(256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
class KernelComponentFactory:
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
### (96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
### (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
### (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]:
@staticmethod
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr pipeline, let 't' padding to appear later!!
@@ -502,16 +568,28 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
assert False
return pipelines
class CustomFactory(KernelComponentFactory):
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
if dtype == 'fp16' or dtype == 'bf16':
if (128, 128) in result.keys():
result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
return result
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
d = factory.get_hdim_tile_size_dict(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()):
for pipeline in get_pipelines(dtype, hdim, hdim_v):
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
@@ -551,7 +629,9 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= mode == 'batch'
cond &= pipeline.F_skip == 'f'
cond &= pipeline.F_logits == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration

View File

@@ -332,6 +332,12 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= pipeline.F_vlayout == 'row'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16, bf16']
cond &= pipeline.F_vlayout == 'row'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = {
64 : 64,
96 : 128,
128: 128,
# 160: 160,
256: 256
}
@@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
@@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
### '96' : FmhaFwdSplitKVCombineTileSize(32, -1),
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
### '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
@@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
# TODO: use async pipeline when compiler is more stable
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]:
if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]:
# if True:
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
@@ -751,6 +754,15 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) ->
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16, bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= mode == 'batch'
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']

View File

@@ -0,0 +1,585 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
import copy
from dataclasses import dataclass
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional, Tuple
from codegen.cmake_config import *
from codegen.cpp_symbol_map import *
DTYPE_BITS = {
"fp32": 32,
"fp16": 16,
"bf16": 16,
"fp8" : 8,
"bf8" : 8
}
K0_MAX_SUBMAX_MAP = {
32 : 32,
64 : 64,
96 : 128,
128: 128,
256: 256
}
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
"qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
}
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by generate.py
#include "ck_tile/ops/fmha/block/variants.hpp"
#include "fmha_fwd.hpp"
"""
FMHA_FWD_KERNEL_BODY="""
using fmha_dtype_{F_idx} = {F_dtype};
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
{F_vlayout}>;
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
{F_skpad},
{F_dpad},
{F_dvpad},
{F_logits},
{F_bias},
false,
{F_lse}, //lse
{F_pagedkv}, //pagedkv
{F_squant},
{F_occupancy},
{F_skip}>;
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
using fmha_mask_{F_idx} = {F_mask};
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem<
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
fmha_shape_{F_idx},
{F_mode},
fmha_variant_{F_idx},
fmha_mask_{F_idx},
fmha_trait_{F_idx}>;
using fmha_pipeline_{F_idx} = {F_pipeline}<
fmha_pipeline_problem_{F_idx}>;
using fmha_epilogue_{F_idx} =
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
{F_spad}, {F_dvpad}>>;
using fmha_kernel_{F_idx} =
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
#include <iostream>
template<>
float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
{{
using k_ = fmha_kernel_{F_idx};
if(s.log_level_ > 0)
std::cout << ", " << k_::GetName() << std::flush;
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
constexpr dim3 blocks = k_::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
return ck_tile::launch_kernel(s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}
"""
FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp"
FMHA_FWD_API="""
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
float r = -1;
{F_dispatch}
return r;
}}
"""
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
{F_hdim_case}
}}
"""
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
{F_inner_dispatch}
}}
"""
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
return fmha_fwd_pagedkv_<trait_>(s, a);
}}
"""
@dataclass
class FmhaFwdApiTrait:
pipeline_tag : str
# sync with fmha_fwd_traits<>, to generate fallback calls
hdim : str
dtype : str # data type
mode : str # value from MODE_MAP
bm0 : int # tile size along q seqlen (block size)
bn0 : int # tile size along qk seqlen
bk0 : int # tile size along qk gemm unroll
bn1 : int # tile size along v head_dim
bk1 : int # tile size along kv gemm unroll
bk0max : int
vlayout : str
logits : str
mask : str
bias : str #
lse : str #
pagedkv : str
squant : str #
spad : str
skpad : str
dpad : str
dvpad : str
skip : str
@property
def name(self) -> str:
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
@property
def scheck(self) -> str:
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def skcheck(self) -> str:
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
if self.pipeline_tag == 'qr_async':
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.seqlen_k % {self.bn0} == 0'
else: assert False
@property
def dcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_q % {bk0submax} == 0'
else: assert False
@property
def dvcheck(self) -> str:
if self.pipeline_tag == 'qr_async':
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
else : return f'a.hdim_v % {bk0submax} == 0'
else: assert False
@dataclass
class FmhaFwdPipeline:
tag : str
F_vlayout : str # row/col
F_spad : str # true/false
F_skpad : str #
F_dpad : str #
F_dvpad : str #
F_logits : str # t/f
F_bias : str # true/false
F_lse : str #
F_pagedkv : str #
F_squant : str #
F_mask : str # value from MASK_MAP
F_skip : str # true/false
@property
def name(self) -> str:
def pad_name() -> str:
n = ''
if self.F_spad == 't': n += 's'
if self.F_skpad == 't' : n += 'sk'
if self.F_dpad == 't' : n += 'd'
if self.F_dvpad == 't' : n += 'dv'
if n != '' : n = 'p' + n
return n
pn = pad_name()
n = f'{self.tag}_v{self.F_vlayout[0]}'
if pn != '' : n += f'_{pn}'
else: n += '_npad'
if self.F_logits == 't' : n += '_logits'
else: n += '_nlogits'
if self.F_bias != 'no' : n += f'_{self.F_bias}'
else: n += '_nbias'
if self.F_mask[0:2] == 's_':
if self.F_mask == 's_mask': n += f'_mask'
else: n += '_nmask'
else:
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
else: n += '_nmask'
if self.F_lse == 't' : n += '_lse'
else: n += '_nlse'
if self.F_skip == 't' : n += '_skip'
else: n += '_nskip'
if self.F_squant == 't' : n += '_squant'
else: n += '_nsquant'
if self.F_pagedkv == 't' : n += '_pagedkv'
else: n += '_npagedkv'
return n
class FmhaFwdApiPool:
def __init__(self, mask_impl):
self.pool = dict()
self.mask_impl = mask_impl
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
# TODO: do we need to check duplication?
if trait.dtype not in self.pool.keys():
self.pool[trait.dtype] = dict()
if trait.hdim not in self.pool[trait.dtype].keys():
self.pool[trait.dtype][trait.hdim] = list()
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
@property
def api(self) -> str:
per_dtypes=str()
for i, dtype in enumerate(self.pool.keys()):
per_hdim_case=str()
for j, hdim in enumerate(self.pool[dtype].keys()):
traits=self.pool[dtype][hdim]
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
if_j = 'if' if j == 0 else 'else if'
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
@dataclass
class FmhaFwdTileSize:
F_bm0 : int # tile size along q seqlen (block size)
F_bn0 : int # tile size along k seqlen
F_bk0 : int # tile size along qk gemm unroll
F_bn1 : int # tile size along v head_dim
F_bk1 : int # tile size along kv gemm unroll
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
F_rm0 : int # number of warps for gemm0 along q seqlen
F_rn0 : int # number of warps for gemm0 along k seqlen
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
F_rm1 : int # number of warps for gemm1 along q seqlen
F_rn1 : int # number of warps for gemm1 along head dim v
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
F_wm0 : int # gemm0 warp size along m
F_wn0 : int # gemm0 warp size along n
F_wk0 : int # gemm0 warp size along k
F_wm1 : int # gemm1 warp size along m
F_wn1 : int # gemm1 warp size along n
F_wk1 : int # gemm1 warp size along k
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
@property
def name(self) -> str:
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
@dataclass
class FmhaFwdKernel:
F_idx : int # this is not a tunable, but a counter to differentiate symbol
F_hdim : int # hdim
F_dtype : str # data type
F_mode : str # value from MODE_MAP
F_tile : FmhaFwdTileSize
F_pipeline : FmhaFwdPipeline
mask_impl : str
@property
def template(self) -> str:
kernel_body = str()
return FMHA_FWD_KERNEL_HEADER + \
FMHA_FWD_KERNEL_BODY.format(
F_idx = self.F_idx,
F_hdim = self.F_hdim,
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
F_bm0 = self.F_tile.F_bm0,
F_bn0 = self.F_tile.F_bn0,
F_bk0 = self.F_tile.F_bk0,
F_bn1 = self.F_tile.F_bn1,
F_bk1 = self.F_tile.F_bk1,
F_bk0max = self.F_tile.F_bk0max,
F_rm0 = self.F_tile.F_rm0,
F_rn0 = self.F_tile.F_rn0,
F_rk0 = self.F_tile.F_rk0,
F_rm1 = self.F_tile.F_rm1,
F_rn1 = self.F_tile.F_rn1,
F_rk1 = self.F_tile.F_rk1,
F_wm0 = self.F_tile.F_wm0,
F_wn0 = self.F_tile.F_wn0,
F_wk0 = self.F_tile.F_wk0,
F_wm1 = self.F_tile.F_wm1,
F_wn1 = self.F_tile.F_wn1,
F_wk1 = self.F_tile.F_wk1,
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
F_occupancy = self.F_tile.F_occupancy,
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
F_mode = MODE_MAP[self.F_mode],
F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag])
@property
def name(self) -> str:
# TODO: we don't encode idx here
return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
self.F_tile.name + '_' + self.F_pipeline.name
@property
def filename(self) -> str:
return self.name + ".cpp"
def api_trait(self) -> FmhaFwdApiTrait:
return FmhaFwdApiTrait(
pipeline_tag=self.F_pipeline.tag,
hdim=str(self.F_hdim),
dtype=self.F_dtype,
mode=self.F_mode,
bm0=self.F_tile.F_bm0,
bn0=self.F_tile.F_bn0,
bk0=self.F_tile.F_bk0,
bn1=self.F_tile.F_bn1,
bk1=self.F_tile.F_bk1,
bk0max=self.F_tile.F_bk0max,
vlayout=self.F_pipeline.F_vlayout,
mask=self.F_pipeline.F_mask,
logits=self.F_pipeline.F_logits,
bias=self.F_pipeline.F_bias,
lse=self.F_pipeline.F_lse,
pagedkv=self.F_pipeline.F_pagedkv,
squant=self.F_pipeline.F_squant,
spad=self.F_pipeline.F_spad,
skpad=self.F_pipeline.F_skpad,
dpad=self.F_pipeline.F_dpad,
dvpad=self.F_pipeline.F_dvpad,
skip=self.F_pipeline.F_skip)
# TODO: design a more practical way to do it
# this is current supported tile size per hdim
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
return {
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
}
elif dtype == 'fp8' or dtype == 'bf8':
return {
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
}
else:
return None
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
# this function will populate a list possible pipelines
# TODO: the order of List matters! the later in this list will be also be checked later
# TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
squant = 't' if dtype == 'fp8' else 'f'
pipelines = []
if dtype in ['fp16', 'bf16']:
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
elif dtype in ['fp8', 'bf8']:
# TODO
None
elif dtype in ['fp8fp16', 'fp8bf16']:
# TODO
None
else:
assert False
return pipelines
gen = list()
api_pool = FmhaFwdApiPool(mask_impl)
for dtype in FWD_DTYPE_MAP.keys():
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
if d == None:
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
tile = d[hdim_str]
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
# if pipeline.F_pagedkv == 'f':
# continue
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
continue
if hdim == 192 and tile.F_bn1 == 128:
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' :
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
k = FmhaFwdKernel(F_idx=0,
F_hdim=hdim,
F_dtype=dtype,
F_mode=mode,
F_tile=tile,
F_pipeline=pipeline,
mask_impl=mask_impl)
if kernel_filter != '':
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
if optdim_list != [-1]:
if hdim not in optdim_list:
continue
# 2 - Flash attention integration
if receipt in (2, 3):
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'alibi']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# PyTorch integration
elif receipt == 4:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_bias in ['no', 'bias']
cond &= pipeline.F_squant == 'f'
cond &= pipeline.F_skip == 'f'
if not cond:
continue
# Aiter(mha_fwd) integration
elif receipt == 100:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'batch'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# Aiter(mha_varlen_fwd) integration
elif receipt == 200:
cond = dtype in ['fp16', 'bf16']
cond &= mode == 'group'
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# aiter::mha_fwd C++ api integration
elif receipt == 600:
cond = dtype in ['fp16', 'bf16']
cond &= pipeline.F_vlayout == 'row'
cond &= pipeline.F_squant == 'f'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
return (api_pool, gen)
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
(autogen_dir / kernel.filename).write_text(kernel.template)
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
write_single_fwd_kernel(kernel, output_dir)
write_fwd_api(api_pool, output_dir)
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
with file_path.open('a') as f:
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_bwd.hpp"
#include "ck_tile/host.hpp"
@@ -355,7 +355,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(bias.type == bias_enum::alibi)
{
auto slopes = ck_tile::get_alibi_slopes<AccDataType>(nhead);
assert(slopes.size() == nhead);
assert(slopes.size() == static_cast<decltype(slopes.size())>(nhead));
if(bias.rank_info == 0)
{
// alibi in 1*h
@@ -756,22 +756,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(p_drop > 0)
{
p_hp_host_ref.ForEach(
[&](auto& self, auto idx) { p_dropped_hp_host_ref(idx) = self(idx); });
p_dropped_hp_host_ref = p_hp_host_ref;
randval_host_ref.ForEach([&](auto& self, auto idx) {
self(idx) = randval_host(b, idx[0], idx[1] + query_offset, idx[2]);
});
ck_tile::reference_batched_dropout(
p_dropped_hp_host_ref, randval_host_ref, p_undrop_in_uint8_t, rp_undrop);
p_dropped_hp_host_ref.ForEach([&](auto& self, auto idx) {
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
});
p_lp_host_ref = p_dropped_hp_host_ref.template CopyAsType<GemmDataType>();
}
else
{
p_hp_host_ref.ForEach([&](auto& self, auto idx) {
p_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
});
p_lp_host_ref = p_hp_host_ref.template CopyAsType<GemmDataType>();
}
// O = P * V
@@ -854,29 +849,27 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
// dS_i_j = P_i_j .* (dP_i_j - dO_i dot O_i)
ds_hp_host_ref.ForEach([&](auto& self, auto idx_gmn) {
AccDataType do_dot_o = 0;
for(int o = 0; o < hdim_v; o++)
{
auto idx_gmo = idx_gmn;
idx_gmo[2] = o;
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(idx_gmo)) *
ck_tile::type_convert<AccDataType>(o_host_refs[wb](idx_gmo));
}
self(idx_gmn) = ck_tile::type_convert<AccDataType>(
p_hp_host_refs[wb](idx_gmn) * (dp_hp_host_ref(idx_gmn) - do_dot_o));
});
ck_tile::make_ParallelTensorFunctor(
[&](auto i0, auto i1, auto i2) {
AccDataType do_dot_o = 0;
for(int o = 0; o < hdim_v; o++)
{
do_dot_o += ck_tile::type_convert<AccDataType>(do_host_ref(i0, i1, o)) *
ck_tile::type_convert<AccDataType>(o_host_refs[wb](i0, i1, o));
}
ds_hp_host_ref(i0, i1, i2) = ck_tile::type_convert<AccDataType>(
p_hp_host_refs[wb](i0, i1, i2) * (dp_hp_host_ref(i0, i1, i2) - do_dot_o));
},
ds_hp_host_ref.mDesc.get_lengths()[0],
ds_hp_host_ref.mDesc.get_lengths()[1],
ds_hp_host_ref.mDesc.get_lengths()[2])(std::thread::hardware_concurrency());
if(use_dbias)
{
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
dbias_host_ref(idx) = ck_tile::type_convert<BiasGradDataType>(self(idx));
});
dbias_host_ref = ds_hp_host_ref.template CopyAsType<BiasGradDataType>();
}
ds_hp_host_ref.ForEach([&](auto& self, auto idx) {
ds_lp_host_ref(idx) = ck_tile::type_convert<GemmDataType>(self(idx));
});
ds_lp_host_ref = ds_hp_host_ref.template CopyAsType<GemmDataType>();
// dV = P_drop^T@dO^T
// dV = P^T@dO^T w/o dropout

89
example/ck_tile/01_fmha/fmha_fwd.cpp Normal file → Executable file
View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "fmha_fwd.hpp"
#include "ck_tile/host.hpp"
@@ -178,50 +178,30 @@ auto get_elimit<FmhaFwdFp8>(std::string init_method)
}
}
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits)
{
// If we have enough to almost fill the SMs, then just use 1 split
if(batch_nhead_mblocks >= 0.8f * num_SMs)
{
return 1;
}
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
max_splits = std::min({max_splits, num_SMs});
float max_efficiency = 0.f;
std::vector<float> efficiency;
efficiency.reserve(max_splits);
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
// (i.e. it's 11 splits anyway).
// So we check if the number of blocks per split is the same as the previous num_splits.
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
return num_splits == 1 ||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
};
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if(eff > max_efficiency)
{
efficiency.push_back(0.f);
}
else
{
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
float eff = n_waves / ceil(n_waves);
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
if(eff > max_efficiency)
{
max_efficiency = eff;
}
efficiency.push_back(eff);
max_efficiency = eff;
}
efficiency.push_back(eff);
}
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
{
if(!is_split_eligible(num_splits))
{
continue;
}
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
{
// printf("num_splits chosen = %d\n", num_splits);
@@ -234,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks,
int override_num_splits_if_necessary(
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
{
(void)hdim_v;
int device;
auto status = hipGetDevice(&device);
if(status != hipSuccess)
@@ -250,15 +231,13 @@ int override_num_splits_if_necessary(
// tile size should match the generate.py
const int kM0 = 64;
const int kN1 = hdim_v;
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
if(num_splits < 1 && p_drop == 0.0f)
{
return num_splits_heuristic(
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128);
}
return num_splits;
@@ -344,7 +323,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
ck_tile::index_t page_block_size = arg_parser.get_int("page_block_size");
#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
#if(!(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || \
CK_TILE_FMHA_FWD_PAGEDKV_API))
if(0 < page_block_size)
{
std::cerr << "paged-kvcache is not supported. ignoring the 'page_block_size' option"
@@ -360,7 +340,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
}
bool use_cache_batch_idx = arg_parser.get_bool("cache_batch_idx");
#if !CK_TILE_FMHA_FWD_APPENDKV_API && !CK_TILE_FMHA_FWD_SPLITKV_API
#if !(CK_TILE_FMHA_FWD_APPENDKV_API || CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API)
if(use_cache_batch_idx)
{
std::cerr << "split-kv is not supported. ignoring the 'cache_batch_idx' option"
@@ -542,8 +522,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
max_seqlen_k = real_seqlen_k;
}
flop += nhead * (static_cast<std::size_t>(2) * real_seqlen_q * real_seqlen_k * hdim_q +
static_cast<std::size_t>(2) * real_seqlen_q * hdim_v * real_seqlen_k);
flop += nhead * (static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_q +
static_cast<std::size_t>(2) * mask.get_unmaskarea() * hdim_v);
num_byte += nhead * (sizeof(QDataType) * real_seqlen_q * hdim_q +
sizeof(KDataType) * real_seqlen_k * hdim_q +
@@ -568,7 +548,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
return false;
}
#if CK_TILE_FMHA_FWD_SPLITKV_API
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
if(0 < p_drop && (1 < num_splits || use_kvcache))
{
std::cerr << "dropout is not supoprted by split-kv kernels. ignoring the 'p_drop' option"
@@ -832,7 +812,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
<< (is_rotary_interleaved ? "inter" : "half") << ")";
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_DECODE_API
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API || CK_TILE_FMHA_FWD_DECODE_API
if(1 < num_splits)
{
std::cout << ", num_splits:" << num_splits;
@@ -873,6 +853,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
{
traits.has_dropout = (p_drop > 0.0f);
}
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_traits,
std::decay_t<decltype(traits)>>)
{
traits.use_pagedkv = use_kvcache;
}
}
};
@@ -1081,6 +1066,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
args.split_stride_lse_acc = split_stride_lse_acc;
args.split_stride_o_acc = split_stride_o_acc;
}
else if constexpr(std::is_same_v<fmha_fwd_pagedkv_args, std::decay_t<decltype(args)>>)
{
args.block_table_ptr =
(0 < page_block_size ? block_table_buf.GetDeviceBuffer() : nullptr);
args.batch_stride_block_table = batch_stride_block_table;
args.page_block_size = page_block_size;
args.is_gappy = false; // use 'false' for flash-attention integration
args.cache_batch_idx =
(use_cache_batch_idx ? cache_batch_idx_buf.GetDeviceBuffer() : nullptr);
}
}
};
@@ -1102,7 +1098,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const float fwd_ave_time = [&] {
#if CK_TILE_FMHA_FWD_SPLITKV_API
if(1 < num_splits || use_kvcache)
if(1 < num_splits && use_kvcache)
{
fmha_fwd_splitkv_traits fmha_splitkv_traits;
init_traits(fmha_splitkv_traits);
@@ -1120,6 +1116,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
init_args(fmha_decode_args);
return fmha_fwd_decode(fmha_decode_traits, fmha_decode_args, stream_config);
#elif CK_TILE_FMHA_FWD_PAGEDKV_API
if(use_kvcache)
{
fmha_fwd_pagedkv_traits fmha_pagedkv_traits;
init_traits(fmha_pagedkv_traits);
fmha_fwd_pagedkv_args fmha_pagedkv_args;
init_args(fmha_pagedkv_args);
return fmha_fwd_pagedkv(fmha_pagedkv_traits, fmha_pagedkv_args, stream_config);
}
#else
fmha_fwd_traits fmha_traits;
init_traits(fmha_traits);
@@ -1276,7 +1283,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
q_host_ref.ForEach([&](auto& self, auto i) { self(i) = q_host_ref_ro(i); });
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
if(0 < page_block_size) {
if(i_perm) {
k_host_ref.ForEach([&](auto& self, auto i) {
@@ -1327,7 +1334,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
});
}
#endif
#if CK_TILE_FMHA_FWD_SPLITKV_API
#if CK_TILE_FMHA_FWD_SPLITKV_API || CK_TILE_FMHA_FWD_PAGEDKV_API
if(0 < page_block_size) {
if(is_v_rowmajor) {
if(i_perm) {

View File

@@ -178,6 +178,86 @@ struct fmha_fwd_args
drop_seed_offset;
};
struct fmha_fwd_pagedkv_args
{
const void* q_ptr;
const void* k_ptr;
const void* v_ptr;
const void* bias_ptr; // bias or alibi_slope pointer
void* lse_ptr;
void* o_ptr;
void* block_table_ptr;
ck_tile::index_t batch_stride_block_table; // only used if 'block_table_ptr' is not nullptr
ck_tile::index_t page_block_size; // only used if 'block_table_ptr' is not nullptr
bool is_gappy; // differentiate seqstart_k_ptr usage. only used if 'block_table_ptr' is not
// nullptr.
const void* cache_batch_idx;
// the real seqlen_q & seqlen_k are decided by following:
// batch mode: seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k
// group mode: seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// or kargs.seqlen_k_ptr[b]
//
// batch mode (kvcache):
// seqlen_q = kargs.seqlen_q
// seqlen_k = kargs.seqlen_k_ptr[b]
// group mode (kvcache):
// seqlen_q = kargs.seqstart_q_ptr[b + 1] - kargs.seqstart_q_ptr[b]
//
// when is_gappy=true:
// seqlen_k = kargs.seqlen_k_ptr[b]
// seqstart_k_ptr[b] now store local offset of each batch
//
// when is_gappy=false:
// seqlen_k = kargs.seqstart_k_ptr[b + 1] - kargs.seqstart_k_ptr[b]
// or kargs.seqlen_k_ptr[b]
const void* seqstart_q_ptr;
const void* seqstart_k_ptr;
const void* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t batch;
ck_tile::index_t max_seqlen_q;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t nhead_q;
ck_tile::index_t nhead_k;
float scale_s;
float scale_p;
float scale_o;
float logits_soft_cap;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias; // if alibi, b*h need set this to h, 1*h need set this to 0
ck_tile::index_t stride_o;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t window_size_left;
ck_tile::index_t window_size_right;
ck_tile::index_t mask_type;
ck_tile::index_t min_seqlen_q;
};
struct fmha_fwd_splitkv_args
{
const void* q_ptr;
@@ -590,6 +670,114 @@ auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
}
}
template <typename FmhaKernel>
auto fmha_fwd_pagedkv_create_kargs_and_grids(fmha_fwd_pagedkv_args args)
{
assert(args.nhead_q % args.nhead_k == 0);
auto kargs = [&] {
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.o_ptr,
args.seqstart_q_ptr,
args.seqstart_k_ptr,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.is_gappy,
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_k,
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.min_seqlen_q);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.lse_ptr,
args.o_ptr,
args.seqlen_q,
args.seqlen_k,
args.seqlen_k_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.block_table_ptr,
args.batch_stride_block_table,
args.page_block_size,
args.cache_batch_idx,
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_lse,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type);
}
}();
// FmhaKernel::PrintParameters(kargs, args.batch);
if constexpr(FmhaKernel::kIsGroupMode)
{
dim3 grids = FmhaKernel::GridSize(
args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, args.seqlen_k_ptr != nullptr);
return ck_tile::make_tuple(kargs, grids);
}
else
{
dim3 grids =
FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v, false);
return ck_tile::make_tuple(kargs, grids);
}
}
template <typename Kernel>
auto fmha_fwd_splitkv_create_kargs_and_grids(fmha_fwd_splitkv_args args)
{
@@ -966,102 +1154,102 @@ auto fmha_batch_prefill_create_kargs_and_grids(fmha_batch_prefill_args args)
// create group mode kernel arguments
if constexpr(FmhaKernel::kIsGroupMode)
{
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqstart_q_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqstart_q_ptr,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_k,
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_k,
args.batch_stride_v,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
else
{ // create batch mode kernel arguments
return FmhaKernel::MakeKargsImpl(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqlen_q,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
return FmhaKernel::MakeKargs(args.q_ptr,
args.k_ptr,
args.v_ptr,
args.bias_ptr,
args.rand_val_ptr,
args.lse_ptr,
args.o_ptr,
args.seqlen_q,
args.hdim_q,
args.hdim_v,
args.nhead_q,
args.nhead_q / args.nhead_k,
args.num_total_pages,
args.kv_indptr,
args.kv_page_indices,
#if 0 // we assume page_block_size=1 for now
args.kv_last_page_lens,
args.page_block_size,
#endif
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
args.scale_s,
args.scale_p,
args.scale_o,
args.logits_soft_cap,
args.stride_q,
args.stride_k,
args.stride_v,
args.stride_bias,
args.stride_randval,
args.stride_o,
args.nhead_stride_q,
args.nhead_stride_k,
args.nhead_stride_v,
args.nhead_stride_bias,
args.nhead_stride_randval,
args.nhead_stride_lse,
args.nhead_stride_o,
args.batch_stride_q,
args.batch_stride_k,
args.batch_stride_v,
args.batch_stride_bias,
args.batch_stride_randval,
args.batch_stride_lse,
args.batch_stride_o,
args.window_size_left,
args.window_size_right,
args.mask_type,
args.p_drop,
args.s_randval,
args.drop_seed_offset);
}
}();
@@ -1121,6 +1309,57 @@ struct fmha_fwd_traits_
template <typename Traits_>
float fmha_fwd_(const ck_tile::stream_config&, fmha_fwd_args);
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
ck_tile::index_t kM0_,
ck_tile::index_t kN0_,
ck_tile::index_t kK0_,
ck_tile::index_t kN1_,
ck_tile::index_t kK1_,
ck_tile::index_t kK0BlockLength_,
bool kIsVLayoutRowMajor_,
ck_tile::BlockFmhaPipelineEnum FmhaPipelineEnum_,
bool kHasLogitsSoftCap_,
typename FmhaMask_,
ck_tile::BlockAttentionBiasEnum BiasEnum_,
bool kStoreLse_,
bool kIsPagedKV_,
bool kDoFp8StaticQuant_,
bool kPadS_,
bool kPadSK_,
bool kPadD_,
bool kPadDv_,
bool kSkipMinSeqlenQ_ = false>
struct fmha_fwd_pagedkv_traits_
{
static constexpr ck_tile::index_t HDim = HDim_;
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static constexpr ck_tile::index_t kK1 = kK1_;
static constexpr ck_tile::index_t kK0BlockLength = kK0BlockLength_;
static constexpr bool kIsVLayoutRowMajor = kIsVLayoutRowMajor_;
static constexpr auto FmhaPipelineEnum = FmhaPipelineEnum_;
static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kStoreLse = kStoreLse_;
static constexpr bool kIsPagedKV = kIsPagedKV_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr bool kPadS = kPadS_;
static constexpr bool kPadSK = kPadSK_;
static constexpr bool kPadD = kPadD_;
static constexpr bool kPadDv = kPadDv_;
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
};
template <typename Traits_>
float fmha_fwd_pagedkv_(const ck_tile::stream_config&, fmha_fwd_pagedkv_args);
template <ck_tile::index_t HDim_,
typename DataType_,
bool kIsGroupMode_,
@@ -1333,6 +1572,27 @@ struct fmha_fwd_traits
};
float fmha_fwd(fmha_fwd_traits, fmha_fwd_args, const ck_tile::stream_config&);
struct fmha_fwd_pagedkv_traits
{
int hdim_q;
int hdim_v;
std::string data_type;
bool is_group_mode;
bool is_v_rowmajor;
bool has_logits_soft_cap;
mask_enum mask_type;
bias_enum bias_type; // 0:no bias, 1:elementwise bias, 2:alibi. sync with BlockAttentionBiasEnum
bool has_lse = false;
bool use_pagedkv = true;
bool do_fp8_static_quant = false;
bool skip_min_seqlen_q = false;
// TODO: padding check is inside this api
};
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits&,
fmha_fwd_pagedkv_args&,
const ck_tile::stream_config&);
struct fmha_fwd_splitkv_traits
{
int hdim_q;

21
example/ck_tile/01_fmha/mask.hpp Normal file → Executable file
View File

@@ -21,6 +21,8 @@ enum class mask_enum
struct mask_info
{
mask_enum type;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t y, x;
ck_tile::index_t left, right; // FA style SWA left/right
@@ -42,6 +44,8 @@ struct mask_info
ck_tile::index_t x_total = seqlen_k;
ck_tile::index_t y_total = seqlen_q;
mask_info tmp;
tmp.seqlen_q = seqlen_q;
tmp.seqlen_k = seqlen_k;
auto found_0 = str.find(':');
if(found_0 != std::string::npos)
{
@@ -148,7 +152,22 @@ struct mask_info
}
return tmp;
}
ck_tile::index_t get_unmaskarea() const
{
if(type == mask_enum::no_mask)
return seqlen_q * seqlen_k;
ck_tile::index_t area = 0;
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
{
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
if(x_end > x_start)
{
area += (x_end - x_start);
}
}
return area;
}
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
{
mi.serialize(os);

View File

@@ -75,22 +75,22 @@ struct layernorm2d_fwd_traits_
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
}
else
{
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
@@ -98,13 +98,13 @@ struct layernorm2d_fwd_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();

View File

@@ -1,5 +1,6 @@
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_universal EXCLUDE_FROM_ALL universal_gemm.cpp)
add_executable(tile_example_gemm_weight_preshuffle EXCLUDE_FROM_ALL gemm_weight_preshuffle.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)

View File

@@ -30,7 +30,7 @@ args:
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-prec data type. fp16/bf16/fp8/bf8/int8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)

View File

@@ -212,6 +212,11 @@ int run_gemm_example(int argc, char* argv[])
return run_gemm_example_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i8")
{
return run_gemm_example_prec_type<ck_tile::int8_t, ck_tile::int8_t, int32_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType

View File

@@ -1,4 +1,3 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
@@ -15,6 +14,40 @@
#define CK_TILE_PIPELINE_MEMORY 2
#define CK_TILE_PIPELINE_COMPUTE_V4 3
#define CK_TILE_PIPELINE_COMPUTE_V5 4
#define CK_TILE_PIPELINE_PRESHUFFLE 5
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(__gfx950__)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile_flatmm()
{
#if defined(__gfx950__)
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 64;
else
return sizeof(PrecType) == 2 ? 32 : 128;
#else
if constexpr(M_Warp_Tile == 32)
return sizeof(PrecType) == 2 ? 16 : 32;
else
return sizeof(PrecType) == 2 ? 32 : 64;
#endif
}
struct GemmConfigBase
{
@@ -34,6 +67,7 @@ struct GemmConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool Preshuffle = false;
};
template <typename PrecType>
@@ -90,7 +124,7 @@ struct GemmConfigComputeV3 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
@@ -109,7 +143,7 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
@@ -128,7 +162,7 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 32 : 128;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
@@ -151,7 +185,7 @@ struct GemmConfigComputeV4 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
@@ -170,7 +204,7 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = true;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
@@ -189,13 +223,57 @@ struct GemmConfigComputeV5 : public GemmConfigBase
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 16 : 64;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr bool DoubleSmemBuffer = false;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
};
template <typename PrecType>
struct GemmConfigPreshufle_1 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
struct GemmConfigPreshufle_2 : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE;
static constexpr bool Preshuffle = true;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
struct GemmTypeConfig;
@@ -245,6 +323,15 @@ struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
using CDataType = ck_tile::half_t;
};
template <>
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
{
using ADataType = ck_tile::int8_t;
using BDataType = ck_tile::int8_t;
using AccDataType = int32_t;
using CDataType = int32_t;
};
template <typename T>
struct DataTypeTraits;
@@ -260,6 +347,12 @@ struct DataTypeTraits<double>
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
@@ -290,6 +383,12 @@ struct DataTypeTraits<ck_tile::pk_int4_t>
static constexpr const char* name = "pk_int4_t";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
template <ck_tile::index_t PipelineId>
struct PipelineTypeTraits;
@@ -329,6 +428,16 @@ struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct PipelineTypeTraits<CK_TILE_PIPELINE_PRESHUFFLE>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
template <typename PipelineProblem>
using UniversalGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1<PipelineProblem>;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;

View File

@@ -0,0 +1,294 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "gemm_utils.hpp"
#include "run_gemm_example.inc"
template <typename GemmConfig,
typename ADataType,
typename BDataType,
typename DsDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
bool Persistent,
typename CDEElementWise>
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
GemmConfig::PermuteA,
GemmConfig::PermuteB>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
ALayout,
BLayout,
ELayout,
GemmConfig::NumWaveGroups>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
ELayout,
CDEElementWise,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
dim3 grids;
if constexpr(Persistent)
{
grids = Kernel::MaxOccupancyGridSize(s);
}
else
{
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
}
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl;
}
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
static constexpr ck_tile::index_t APackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
static constexpr ck_tile::index_t BPackedSize =
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck_tile::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
run_flush_cache,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
else
{
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
}
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
template <typename GemmConfig,
typename APrecType,
typename BPrecType = APrecType,
typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
auto [result, arg_parser] = create_args(argc, argv);
bool preshuffle = GemmConfig::Preshuffle;
if(preshuffle && a_layout != "R" && b_layout != "C")
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
}
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
template <template <typename PreType> typename GemmConfig>
int run_gemm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "fp8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
ck_tile::fp8_t,
ck_tile::fp8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
ck_tile::bf8_t,
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[])
{
try
{
return !run_gemm_example<GemmConfigPreshufle_1>(argc, argv);
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
// Return a non-zero code to indicate failure
return EXIT_FAILURE;
}
return EXIT_SUCCESS;
}

View File

@@ -251,6 +251,22 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
return ave_time;
}
template <typename GemmConfig, typename T>
auto shuffle_b(const ck_tile::HostTensor<T>& t)
{
assert(t.get_lengths().size() == 2);
int n_ = t.get_lengths()[1];
int k_ = t.get_lengths()[0];
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
GemmConfig::N_Warp_Tile,
k_ / GemmConfig::K_Warp_Tile,
divisor,
GemmConfig::K_Warp_Tile / divisor});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
}
template <typename GemmConfig,
typename ADataType,
typename BDataType = ADataType,
@@ -284,6 +300,8 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::index_t init_method = arg_parser.get_int("init");
bool persistent = arg_parser.get_int("persistent");
const bool preshuffle = GemmConfig::Preshuffle;
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
@@ -297,8 +315,8 @@ int run_gemm_example_with_layouts(int argc,
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
}
else if(init_method == 1)
{
@@ -316,7 +334,7 @@ int run_gemm_example_with_layouts(int argc,
b_k_n.SetZero();
}
if(GemmConfig::UseStructuredSparsity)
if(!preshuffle && GemmConfig::UseStructuredSparsity)
{
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
}
@@ -326,33 +344,43 @@ int run_gemm_example_with_layouts(int argc,
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
static_assert(!GemmConfig::PermuteA, "Not implemented");
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
if constexpr(preshuffle)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PermuteB)
{
permute_tensor_b<GemmConfig,
decltype(b_k_n_dev),
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(b_k_n_dev);
}
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<GemmConfig>(b_k_n);
// shuffled buffer B for device implementation
b_k_n_dev_buf.ToDevice(b_shuffle_host.data());
}
else
{
if constexpr(GemmConfig::PermuteB)
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
std::cout << "Permute for this DataType is not implemented." << std::endl;
return false;
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
if constexpr(GemmConfig::PermuteB)
{
permute_tensor_b<GemmConfig,
decltype(b_k_n_dev),
ADataType,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(b_k_n_dev);
}
permute_vectors_i4x4_b(b_k_n_dev);
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
}
else
{
if constexpr(GemmConfig::PermuteB)
{
std::cout << "Permute for this DataType is not implemented." << std::endl;
return false;
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
a_m_k_dev_buf.ToDevice(a_m_k.data());
@@ -415,29 +443,23 @@ int run_gemm_example_with_layouts(int argc,
// Restore input for B for gpu reference
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
if constexpr(GemmConfig::Preshuffle)
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
// memory on host to store gpu reference result
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
// memory on device to store gpu reference result
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
c_m_n_gpu_ref.SetZero();
c_m_n_gpu_buf_ref.SetZero();
ADataType* d_A;
BDataType* d_B;
CDataType* d_C;
ck_tile::hip_check_error(hipMalloc(&d_A, a_m_k.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(hipMalloc(&d_B, b_k_n.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(
hipMalloc(&d_C, c_m_n_dev_result.get_element_space_size_in_bytes()));
ck_tile::hip_check_error(hipMemcpy(d_A,
a_m_k_dev_buf.GetDeviceBuffer(),
a_m_k.get_element_space_size_in_bytes(),
hipMemcpyHostToDevice));
ck_tile::hip_check_error(hipMemcpy(d_B,
b_k_n_dev_buf.GetDeviceBuffer(),
b_k_n.get_element_space_size_in_bytes(),
hipMemcpyHostToDevice));
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
ck_tile::reference_gemm_gpu<ADataType,
BDataType,
@@ -447,16 +469,8 @@ int run_gemm_example_with_layouts(int argc,
BLayout,
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
ck_tile::hip_check_error(hipMemcpy(c_m_n_gpu_buf_ref.GetDeviceBuffer(),
d_C,
c_m_n_dev_result.get_element_space_size_in_bytes(),
hipMemcpyDeviceToHost));
ck_tile::hip_check_error(hipFree(d_A));
ck_tile::hip_check_error(hipFree(d_B));
ck_tile::hip_check_error(hipFree(d_C));
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
const float max_accumulated_value =
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(

View File

@@ -59,7 +59,8 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
Persistent,
GemmConfig::NumWaveGroups>;
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
@@ -71,7 +72,6 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run =
@@ -92,6 +92,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
@@ -101,7 +102,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
DsLayout,
ELayout,
CDEElementWise,
GemmPipelineProblem::kBlockSize,
UniversalGemmProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
@@ -112,6 +113,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
UniversalGemmProblem::TransposeC,
memory_operation,
GemmConfig::NumWaveGroups>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
@@ -135,7 +137,7 @@ float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile:
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << GemmShape::GetName() << '\n'
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
<< "problem: " << UniversalGemmProblem::GetName() << '\n'
<< "pipeline: " << GemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
@@ -214,8 +216,21 @@ template <typename GemmConfig,
typename CPrecType = APrecType>
int run_gemm_example_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
auto [result, arg_parser] = create_args(argc, argv);
bool preshuffle = GemmConfig::Preshuffle;
if(preshuffle && std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
{
throw std::runtime_error("Preshuffle is not supported for this int4 datatype!");
}
if(preshuffle && a_layout != "R" && b_layout != "C")
{
throw std::runtime_error(
"Preshuffle is supported only for A(Row major), B(column major) input matrices!");
}
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
{
@@ -299,6 +314,13 @@ int run_gemm_example(int argc, char* argv[])
ck_tile::bf8_t,
ck_tile::half_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "int8")
{
return run_gemm_example_prec_type<GemmConfig<ck_tile::int8_t>,
ck_tile::int8_t,
ck_tile::int8_t,
ck_tile::int32_t>(a_layout, b_layout, argc, argv);
}
else if(data_type == "pk_int4_t")
{
// TODO: Add support for bhalf_t ADataType

View File

@@ -35,7 +35,7 @@ struct Reduce2dShape
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t BlockSize =
WarpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
template <typename XDataType_,

View File

@@ -15,13 +15,14 @@ auto create_args(int argc, char* argv[])
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "0", "cold iter")
.insert("repeat", "1", "hot iter");
.insert("repeat", "1", "hot iter")
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
template <typename DataType, int USEModelSensitive>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
@@ -81,8 +82,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
false, // kSaveInvRms
false, // kSaveUnquant
kTwoPass,
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP>; // fuse quant
ck_tile::Rmsnorm2dFusedAddEnum::NO_ADD, // fuse add
ck_tile::Rmsnorm2dFusedQuantEnum::NO_SWEEP, // fuse quant
static_cast<ck_tile::Rmsnorm2dSensitiveEnum>(
USEModelSensitive)>;
using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem<XDataType,
GammaDataType,
@@ -97,7 +100,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<Problem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<Problem>;
using Pipeline = std::conditional_t<kTwoPass, TwoPassPipeline, OnePassPipeline>;
using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass<Problem>;
using Pipeline =
std::conditional_t<(PipelineTraits::kUseModelSensitiveRMSNorm ==
ck_tile::Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL ||
PipelineTraits::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
std::conditional_t<PipelineTraits::kTwoPass,
TwoPassPipeline,
OnePassPipeline>, // kUseModelSensitiveRMSNorm
// == 0
T5PassPipeline>;
using Default2DEpilogueProblem = ck_tile::
Default2DEpilogueProblem<ComputeDataType, YDataType, false, PipelineTraits::kPadN, false>;
@@ -172,7 +185,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", stride:" << stride
<< ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
<< ", s:" << USEModelSensitive << ", valid:" << (pass ? "y" : "n") << std::flush
<< std::endl;
}
return pass;
@@ -184,10 +198,19 @@ int main(int argc, char* argv[])
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
const std::string data_type = arg_parser.get_str("prec");
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
if(use_model_sensitive_rmsnorm == 0) // 0: for no specific RMSNorm
{
return run<ck_tile::half_t, 0>(arg_parser) ? 0 : -2;
}
else if(use_model_sensitive_rmsnorm == 1) // 1: for T5-like RMSNorm
{
return run<ck_tile::half_t, 1>(arg_parser) ? 0 : -2;
}
}
return -3;

View File

@@ -65,7 +65,8 @@ template <typename XDataType_,
bool kSaveUnquant_,
bool kTwoPass_,
ck_tile::index_t kFusedAdd_ = 0,
ck_tile::index_t kFusedQuant_ = 0>
ck_tile::index_t kFusedQuant_ = 0,
ck_tile::index_t kUseModelSensitiveRMSNorm_ = 0>
struct rmsnorm2d_fwd_traits_
{
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
@@ -74,22 +75,22 @@ struct rmsnorm2d_fwd_traits_
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
}
else
{
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
@@ -97,13 +98,13 @@ struct rmsnorm2d_fwd_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();
@@ -127,8 +128,9 @@ struct rmsnorm2d_fwd_traits_
static constexpr bool kSaveInvRms = kSaveInvRms_;
static constexpr bool kSaveUnquant = kSaveUnquant_;
static constexpr bool kTwoPass = kTwoPass_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
static constexpr ck_tile::index_t kUseModelSensitiveRMSNorm = kUseModelSensitiveRMSNorm_;
};
template <typename XDataType_,
@@ -146,7 +148,8 @@ template <typename XDataType_,
bool kSaveUnquant_,
bool kTwoPass_,
int kFusedAdd_,
int kFusedQuant_>
int kFusedQuant_,
int kUseModelSensitiveRMSNorm_>
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
YDataType_,
SmoothScaleDataType_,
@@ -162,7 +165,8 @@ using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
kSaveUnquant_,
kTwoPass_,
kFusedAdd_,
kFusedQuant_>;
kFusedQuant_,
kUseModelSensitiveRMSNorm_>;
"""
API_COMMON_HEADER = """
@@ -197,7 +201,8 @@ float rmsnorm2d_fwd_(const S& s, A a)
Traits_::kSaveUnquant,
Traits_::kTwoPass,
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant),
static_cast<ck_tile::Rmsnorm2dSensitiveEnum>(Traits_::kUseModelSensitiveRMSNorm)>;
using PipelineProblem =
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
@@ -213,7 +218,13 @@ float rmsnorm2d_fwd_(const S& s, A a)
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using T5PassPipeline = ck_tile::Rmsnorm2dFwdPipelineModelSensitiveT5Pass<PipelineProblem>;
using Pipeline = std::conditional_t<
(Traits_::kUseModelSensitiveRMSNorm == 0 || Traits_::kTwoPass), // TODO: consider TwoPass for T5PassPipeline
std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>, // kUseModelSensitiveRMSNorm == 0
T5PassPipeline
>;
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
@@ -387,12 +398,13 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
F_kTwoPass : bool
F_kFusedAdd : int
F_kFusedQuant : int
F_use_model_sensitive_rmsnorm : int
@property
def trait_name(self) ->str:
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}'
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}, {self.F_use_model_sensitive_rmsnorm:4}'
return t_
# string when calling this kernel
@@ -413,6 +425,7 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
F_add : int
F_sweep : int
F_saveunquant : bool
F_use_model_sensitive_rmsnorm : int
instance_list : List[Any] # List[h_traits]
@property
@@ -426,6 +439,10 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
if self.F_saveunquant:
nnn = nnn + '_saveunquant'
if self.F_use_model_sensitive_rmsnorm == 0:
nnn = nnn + '_nsm'
elif self.F_use_model_sensitive_rmsnorm == 1:
nnn = nnn + '_t5ml'
return nnn
@property
@@ -481,9 +498,9 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
elif ins.F_kFusedQuant == 2:
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}) && (t.use_model_sensitive_rmsnorm == {f_use_model_sensitive_rmsnorm}) )'.format(
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
f_sweep_cond = _sweep_cond)
f_sweep_cond = _sweep_cond, f_use_model_sensitive_rmsnorm = ins.F_use_model_sensitive_rmsnorm)
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
F_VEC_COND = _cond, F_instance_func=ins.call_name)
#inner_str = inner_str + vec_str
@@ -516,85 +533,149 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
bool_list = [False, True]
# rm rn tm tn vn pd mv unquant 2p add sweep
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)],
'128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)],
'256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)],
'512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)],
'640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)],
'768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]}
h_trait_dicts = {
0: {
# rm rn tm tn vn pd mv unquant 2p add sweep srm
'64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 0)],
'128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 0)],
'256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 0)],
'512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 0)],
'640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 0)],
'768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 0)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 0)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 0)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 0)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 0)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 0)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 0)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 0)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 0),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 0)]
},
1: {
# rm rn tm tn vn pd mv unquant 2p add sweep srm
'64' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0, 1)],
'128' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0, 1)],
'256' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 32, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0, 1)],
'512' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0, 1)],
'640' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0, 1)],
'768' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0, 1)],
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 2, 128, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0, 1)],
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0, 1)],
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0, 1)],
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0, 1)],
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0, 1)],
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0, 1)],
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0, 1)],
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
}
}
total_blob = list()
for hs_key in h_trait_dict:
hs = h_trait_dict[hs_key]
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
prec_i, prec_o = dtype.split(',')
scale_sm, scale_y = scale_type.split(',')
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
continue # skip non dynamic quant case
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
continue
if (fused_quant == 0 and save_unquant == True):
continue # save_unquant should always be false when there is no quant enabled
current_hs = list()
for chs_ in hs:
h_ = copy.copy(chs_) # copy the base instance out
h_.F_XDataType = prec_i
h_.F_YDataType = prec_o
h_.F_SmoothScaleDataType = scale_sm
h_.F_YScaleDataType = scale_y
h_.F_UnquantYDataType = prec_i
h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant
h_.F_kSaveUnquant = save_unquant
current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs))
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
current_trait_dict = h_trait_dicts[model_sensitive_flag]
for hs_key in current_trait_dict:
hs = current_trait_dict[hs_key]
current_n = hs_key
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
prec_i, prec_o = dtype.split(',')
scale_sm, scale_y = scale_type.split(',')
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
continue # skip non dynamic quant case
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
continue
if (fused_quant == 0 and save_unquant == True):
continue # save_unquant should always be false when there is no quant enabled
current_hs = list()
for chs_ in hs:
h_ = copy.copy(chs_) # copy the base instance out
h_.F_XDataType = prec_i
h_.F_YDataType = prec_o
h_.F_SmoothScaleDataType = scale_sm
h_.F_YScaleDataType = scale_y
h_.F_UnquantYDataType = prec_i
h_.F_kFusedAdd = fused_add
h_.F_kFusedQuant = fused_quant
h_.F_kSaveUnquant = save_unquant
current_hs.append(h_) # + "\n"
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
current_n_str = 'big' if hs_key == 'big' else current_n
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, h_.F_use_model_sensitive_rmsnorm, current_hs))
return total_blob
def list_blobs(self) -> None:
@@ -712,4 +793,4 @@ if __name__ == "__main__":
if args.list_blobs:
list_blobs(args)
else:
gen_blobs(args)
gen_blobs(args)

View File

@@ -52,7 +52,8 @@ auto create_args(int argc, char* argv[])
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
.insert("repeat", "20", "hot iter")
.insert("s", "0", "sensitive model mode, 0: for no specific model, 1: for T5-like model");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
@@ -66,15 +67,16 @@ template <typename InDataType,
bool SaveUnquant>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
@@ -194,10 +196,17 @@ bool run(const ck_tile::ArgParser& arg_parser)
std::cout << "[" << prec_str << "]"
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << std::flush;
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
rmsnorm2d_fwd_traits traits{
prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant};
rmsnorm2d_fwd_traits traits{prec_i,
prec_o,
prec_sm,
prec_sy,
SaveRms,
SaveUnquant,
fused_add,
fused_quant,
use_model_sensitive_rmsnorm};
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,

View File

@@ -64,6 +64,8 @@ struct rmsnorm2d_fwd_traits
bool save_unquant;
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
int use_model_sensitive_rmsnorm = 0; // 0: Use default RMSNorm; 1: Use T5-like implementation
};
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);

View File

@@ -1,37 +1,74 @@
#!/bin/sh
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000
# 0: for no specific RMSNorm
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=0
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=0
# 1: for T5-like RMSNorm
$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 -s=1
$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1
$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec_i=fp16 -repeat=1000 -s=1

View File

@@ -5,29 +5,32 @@ for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -p
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=80 -n=127
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=19 -n=512
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=11 -n=510
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=91 -n=636
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=12 -n=768 -stride=800
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=31 -n=1024
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=8192
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
done
done
done
done
@@ -36,8 +39,11 @@ done
for fquant in ""
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=1 -n=10547
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
done

View File

@@ -49,22 +49,22 @@ struct smoothquant_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
}
else
{
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
@@ -72,13 +72,13 @@ struct smoothquant_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();

View File

@@ -35,7 +35,20 @@ auto create_args(int argc, char* argv[])
.insert("e", "8", "number of num_experts")
.insert("k", "4", "topk")
.insert("unit", "32", "unit_size")
#if MOE_SORTING_FMOE_2D_BUF
.insert("moe_buf_interm_dim", "0", "interm_dim(col) of the following fmoe buf")
.insert(
"moe_buf_elem_bytes", "2", "fmoe buf element byte size, 1:8bit, 2:16bit, 4:32bit...")
#else
.insert("moe_buf_size", "0", "moe_buf_size")
#endif
.insert("ci",
"1",
"clear workspace inside API or not(if \"0\", require manually clear outside)")
.insert(
"dispatch",
"0",
"dispatch policy. 0:automatically pick up kernel, 1:use single kernel, 2:use mp kernel")
.insert("local_eid",
"-1",
"a list of experts enabled as local expert. e.g. \"0,1,4,5\"\n"
@@ -88,10 +101,17 @@ bool test_moe_sorting(ck_tile::ArgParser args)
int topk = args.get_int("k");
int seed = args.get_int("seed");
int unit_size = args.get_int("unit");
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
#if MOE_SORTING_FMOE_2D_BUF
int moe_buf_interm_dim = args.get_int("moe_buf_interm_dim");
int moe_buf_elem_bytes = args.get_int("moe_buf_elem_bytes");
#else
int64_t moe_buf_size = static_cast<int64_t>(args.get_uint64("moe_buf_size"));
#endif
int kname = args.get_int("kname");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
bool clear_inside = args.get_int("ci") != 0;
int dispatch_policy = args.get_int("dispatch");
int max_output_ids =
ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size);
@@ -149,11 +169,26 @@ bool test_moe_sorting(ck_tile::ArgParser args)
ck_tile::HostTensor<IndexType> sorted_ids_host({max_output_ids}, {1});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_output_ids}, {1});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_output_ids / unit_size}, {1});
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({1}, {1});
// for simplicity, below buffer allocate 2 dword
ck_tile::HostTensor<IndexType> sorted_id_cnt_host({2}, {1});
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::HostTensor<int8_t> moe_buf_host(
{static_cast<std::size_t>(is_local_token ? local_tokens : tokens) * moe_buf_interm_dim *
moe_buf_elem_bytes});
auto moe_buf_bytes = moe_buf_interm_dim == 0 ? static_cast<std::size_t>(0)
: moe_buf_host.get_element_space_size_in_bytes();
#else
ck_tile::HostTensor<float> moe_buf_host({moe_buf_size});
auto moe_buf_bytes = moe_buf_size == 0 ? static_cast<std::size_t>(0)
: moe_buf_host.get_element_space_size_in_bytes();
#endif
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(weights_host);
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::FillUniformDistribution<int8_t>{-.5f, .5f}(moe_buf_host);
#else
ck_tile::FillUniformDistribution<WeightType>{-.5f, .5f}(moe_buf_host);
#endif
topid_unique_gen<IndexType>(topk_ids_host.mData, tokens, topk, num_experts, seed);
ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes());
@@ -176,7 +211,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
topk_ids_dev.ToDevice(topk_ids_host.data());
weights_dev.ToDevice(weights_host.data());
if(moe_buf_size > 0)
if(moe_buf_bytes > 0)
{
moe_buf_dev.ToDevice(moe_buf_host.data());
}
@@ -184,29 +219,31 @@ bool test_moe_sorting(ck_tile::ArgParser args)
local_expert_masking_dev.ToDevice(local_expert_masking_host.data());
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
ck_tile::index_t workspace_size = moe_sorting_get_workspace_size(tokens, num_experts, topk);
ck_tile::index_t workspace_size =
moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0)
if(workspace_size != 0 && clear_inside == false)
moe_sorting_ws.SetZero(); // note, clear here!!!!
moe_sorting_trait trait{index_prec, weight_prec, local_expert_masking};
moe_sorting_trait trait{
index_prec, weight_prec, local_expert_masking, clear_inside, dispatch_policy};
moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(),
weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer()
: nullptr,
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
sorted_ids_dev.GetDeviceBuffer(),
sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(),
sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr,
tokens,
unit_size,
num_experts,
topk,
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))};
moe_sorting_args karg
{
topk_ids_dev.GetDeviceBuffer(), weights_dev.GetDeviceBuffer(),
local_expert_masking ? local_expert_masking_dev.GetDeviceBuffer() : nullptr,
is_local_token ? local_tokens_dev.GetDeviceBuffer() : nullptr,
sorted_ids_dev.GetDeviceBuffer(), sorted_weights_dev.GetDeviceBuffer(),
sorted_expert_ids_dev.GetDeviceBuffer(), sorted_id_cnt_dev.GetDeviceBuffer(),
moe_buf_bytes > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr,
workspace_size != 0 ? moe_sorting_ws.GetDeviceBuffer() : nullptr, tokens, unit_size,
num_experts, topk,
#if MOE_SORTING_FMOE_2D_BUF
moe_buf_interm_dim, moe_buf_elem_bytes
#else
static_cast<ck_tile::long_index_t>(moe_buf_size * sizeof(float))
#endif
};
ck_tile::stream_config sc{nullptr,
true,
@@ -219,7 +256,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
#if 0
{
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
ck_tile::HostTensor<char> ws_host({workspace_size}, {1});
moe_sorting_ws.FromDevice(ws_host.data());
int * p_mesh = reinterpret_cast<int*>(ws_host.data());
@@ -268,7 +305,12 @@ bool test_moe_sorting(ck_tile::ArgParser args)
}
#endif
printf("[%s|%s]tokens:%d", index_prec.c_str(), weight_prec.c_str(), tokens);
printf("[%s|%s|%s|%d]tokens:%d",
index_prec.c_str(),
weight_prec.c_str(),
workspace_size == 0 ? "cx" : (clear_inside ? "ci" : "co"),
dispatch_policy,
tokens);
if(is_local_token)
{
printf("(%d)", local_tokens);
@@ -280,6 +322,19 @@ bool test_moe_sorting(ck_tile::ArgParser args)
printf("local_eid:%s, ", args.get_str("local_eid").c_str());
}
if(moe_buf_bytes > 0)
{
#if MOE_SORTING_FMOE_2D_BUF
printf("moe_buf:%lu(%d,%d), ",
static_cast<uint64_t>(moe_buf_bytes),
moe_buf_interm_dim,
moe_buf_elem_bytes);
#else
printf("moe_buf:%lu, ", static_cast<uint64_t>(moe_buf_bytes));
#endif
}
if(ms < 0)
printf("not supported\n");
else
@@ -294,7 +349,7 @@ bool test_moe_sorting(ck_tile::ArgParser args)
sorted_weights_dev.FromDevice(sorted_weights_host.data());
sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data());
sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data());
if(moe_buf_size > 0)
if(moe_buf_bytes > 0)
{
moe_buf_dev.FromDevice(moe_buf_host.data());
}
@@ -340,6 +395,16 @@ bool test_moe_sorting(ck_tile::ArgParser args)
std::string("OUT Error: Incorrect eid!"),
1e-6,
1e-6);
// if(is_local_token)
{
auto t_ = is_local_token ? local_tokens : tokens;
bool _f = t_ == sorted_id_cnt_host.mData[1];
rtn &= _f;
if(!_f)
{
printf("not equal token buffer pad %d(%d)\n", t_, sorted_id_cnt_host.mData[1]);
}
}
}
else
{
@@ -347,9 +412,13 @@ bool test_moe_sorting(ck_tile::ArgParser args)
rtn = false;
}
if(moe_buf_size)
if(moe_buf_bytes)
{
#if MOE_SORTING_FMOE_2D_BUF
ck_tile::HostTensor<int8_t> moe_buf_ref({moe_buf_bytes});
#else
ck_tile::HostTensor<WeightType> moe_buf_ref({moe_buf_size});
#endif
rtn &= ck_tile::check_err(
moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0);
}

View File

@@ -175,7 +175,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
}
}
#else
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk) != 0)
if(moe_sorting_get_workspace_size(a.tokens, a.num_experts, a.topk, t.dispatch_policy) != 0)
{
return moe_sorting_mp(t, a, s);
}
@@ -293,6 +293,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, true)); \
@@ -302,6 +303,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, true, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, true, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, true, false)); \
@@ -314,6 +316,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
{ \
float ave_time = \
ck_tile::launch_kernel(s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, true), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, true), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, true)); \
@@ -323,6 +326,7 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
{ \
float ave_time = ck_tile::launch_kernel( \
s, \
maybe_clear_workspace, \
MOE_SORTING_MP_0(mesh_type_, token_vec_0_, false, false), \
MOE_SORTING_MP_1(mesh_type_, token_vec_1_, false, false), \
MOE_SORTING_MP_23(mesh_type_, token_vec_23_, false, false)); \
@@ -330,6 +334,17 @@ float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_confi
} \
}
#define MOR_SORTING_CLEAR_WS_DISPATCH_(is_local_token_, block_size_, occu_) \
[&]() { \
using problem_ = \
ck_tile::MoeSortingClearWorkspaceProblem<is_local_token_, block_size_, occu_>; \
using kernel = ck_tile::MoeSortingClearWorkspaceKernel<problem_>; \
auto kargs = kernel::MakeKargs(a); \
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(a); \
return ck_tile::make_kernel<kernel::BLOCK_SIZE>(kernel{}, grids, blocks, 0, kargs); \
}()
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s)
{
bool is_local_token = a.p_local_tokens != nullptr;
@@ -338,6 +353,22 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
using ms_index_t = ck_tile::index_t;
using ms_weight_type = float;
auto maybe_clear_workspace = [=](const ck_tile::stream_config& s_) {
if(t.clear_workspace_inside_api)
{
if(is_local_token)
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(true, 1024, 1);
k(s_);
}
else
{
auto k = MOR_SORTING_CLEAR_WS_DISPATCH_(false, 1024, 1);
k(s_);
}
}
};
if(ck_tile::impl::moe_sorting_get_smem_size_p23(a.num_experts) >
ck_tile::get_smem_capacity())
{
@@ -345,6 +376,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
if(t.local_expert_masking)
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, true),
MOE_SORTING_MP_1(ms_index_t, 1, true),
MOE_SORTING_MP_2(ms_index_t, 1, true),
@@ -354,6 +386,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
else
{
float ave_time = ck_tile::launch_kernel(s,
maybe_clear_workspace,
MOE_SORTING_MP_0(ms_index_t, 1, false),
MOE_SORTING_MP_1(ms_index_t, 1, false),
MOE_SORTING_MP_2(ms_index_t, 1, false),
@@ -405,7 +438,7 @@ float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_co
return -1;
}
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk)
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy)
{
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk, dispatch_policy);
}

View File

@@ -10,8 +10,14 @@
struct moe_sorting_trait
{
std::string index_type;
std::string weight_type; // currently always float
bool local_expert_masking; // if mask experts as local expert
std::string weight_type; // currently always float
bool local_expert_masking; // if mask experts as local expert
bool clear_workspace_inside_api; // if true, no need clear workspace outsize (will take care of
// it inside API)
int dispatch_policy; // 0 - let the API choose kernel for you. 1 - always use single kerenl. 2 -
// always use mp kernel NOTE: moe_sorting_get_workspace_size() need use
// same dispatch_policy value. it will be undefined behavior if ppl using
// different value when get ws and call the kernel
};
struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
@@ -22,6 +28,6 @@ struct moe_sorting_args : public ck_tile::MoeSortingHostArgs
// if return non zero, means need workspace, you need to allocate a GPU buffer
// and set to moe_sorting_args.p_ws
// NOTE: workspace size are required to clear zero before use the API
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk);
int moe_sorting_get_workspace_size(int tokens, int num_experts, int topk, int dispatch_policy);
float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);
float moe_sorting_mp(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s);

View File

@@ -1,7 +1,9 @@
# #!/bin/sh
EXE=./build/bin/tile_example_moe_sorting
MOE_BUF="12"
if [ "x$MOE_BUF" = "x1" ] ; then
$EXE -t=80 -e=17 -moe_buf_size=16
$EXE -t=111 -e=117 -moe_buf_size=4
$EXE -t=1000 -e=55 -moe_buf_size=1024
@@ -42,3 +44,46 @@ $EXE -t=23 -local_t=9 -e=1 -k=1
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -moe_buf_size=133940
else
$EXE -t=80 -e=17 -moe_buf_interm_dim=16 -moe_buf_elem_bytes=4
$EXE -t=111 -e=117 -moe_buf_interm_dim=4 -moe_buf_elem_bytes=4
$EXE -t=1000 -e=55 -moe_buf_interm_dim=1024 -moe_buf_elem_bytes=1
$EXE -t=99 -e=120 -moe_buf_interm_dim=10244 -moe_buf_elem_bytes=2
$EXE -t=175 -e=64 -k=8
$EXE -t=65 -e=8 -k=2
$EXE -t=1 -e=25
$EXE -t=31 -e=19 -k=15
$EXE -t=81 -e=37 -k=7
$EXE -t=23 -e=1 -k=1
$EXE -t=127 -e=99 -k=19
$EXE -t=71 -e=11 -k=11
$EXE -t=1 -e=1 -k=1
$EXE -t=99 -e=2 -k=1
$EXE -t=333 -e=99 -k=13
$EXE -t=11 -e=256 -k=5
$EXE -t=64 -e=455 -k=8
$EXE -t=777 -e=802 -k=99
$EXE -t=4097 -e=906 -k=51
$EXE -t=128 -e=32 -k=5 -local_t=6 -moe_buf_interm_dim=262144
$EXE -t=13 -e=64 -k=3 -local_eid=4,5,6,7,8,9,10,11
$EXE -t=99 -e=33 -k=9 -local_eid=6,10,11,15,19
$EXE -t=80 -e=99 -k=10 -local_eid=0,8,12,33
$EXE -t=11 -e=256 -k=5 -local_eid=99,110,129
$EXE -t=128 -e=128 -k=6 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
$EXE -t=8192 -e=32 -k=5 -local_t=11 -moe_buf_interm_dim=163840
$EXE -t=8192 -e=32 -k=8 -local_t=12 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=1
$EXE -t=8192 -e=256 -k=5 -local_t=13 -moe_buf_interm_dim=163840
$EXE -t=8192 -e=256 -k=8 -local_t=8 -moe_buf_interm_dim=163840
$EXE -t=163840 -e=256 -k=8 -local_t=4 -moe_buf_interm_dim=163840 -moe_buf_elem_bytes=4
$EXE -t=12 -local_t=3 -e=256 -k=5 -local_eid=9,10,199,145
$EXE -t=67 -local_t=9 -e=555 -k=5 -local_eid=19,23,24,25,26,99
$EXE -t=99 -local_t=93 -e=121 -local_t=4 -moe_buf_interm_dim=10244
$EXE -t=536 -local_t=345 -e=802 -k=99
$EXE -t=331 -local_t=39 -e=83 -k=33
$EXE -t=765 -local_t=654 -e=783 -k=8
$EXE -t=23 -local_t=9 -e=1 -k=1
$EXE -t=7 -local_t=0 -e=89 -k=1 -local_eid=0,8,12,33
$EXE -t=61 -local_t=0 -e=333 -k=99 -local_eid=0,8,12,33
$EXE -t=133940 -local_t=111921 -e=256 -k=17 -local_t=2 -moe_buf_interm_dim=133940 -moe_buf_elem_bytes=1
fi

View File

@@ -38,22 +38,22 @@ struct moe_smoothquant_traits_
using InputType = ck_tile::remove_cvref_t<InputType_>;
using OutputType = ck_tile::remove_cvref_t<OutputType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
return total_warps * (WarpSize / ThreadPerBlock_N_);
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
}
else
{
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / WarpSize);
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
@@ -61,13 +61,13 @@ struct moe_smoothquant_traits_
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
return ThreadPerBlock_N_ / WarpSize;
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();

View File

@@ -6,7 +6,8 @@
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
{
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
return ck_tile::moe_sorting_get_workspace_size(
tokens, num_experts, topk, 0 /*dispatch policy*/);
}
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
@@ -24,23 +25,28 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
}();
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
auto a0 = fused_moesorting_args{
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.local_tokens,
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.ws_ptr, // void* p_ws;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
o_data_bytes // index_t moe_buf_bytes;
auto a0 = fused_moesorting_args
{
a.topk_ids_ptr, // const void* p_topk_ids;
a.topk_weight_ptr, // const void* p_weights;
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
a.local_tokens,
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
a.sorted_weight_ptr, // void* p_sorted_weights;
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
a.o_ptr, // void* p_moe_buf;
a.ws_ptr, // void* p_ws;
a.num_tokens, // index_t tokens;
a.block_m, // index_t unit_size;
a.num_experts, // index_t num_experts;
a.topk, // index_t topk;
#if MOE_SORTING_FMOE_2D_BUF
a.stride_token, o_data_bytes,
#else
static_cast<ck_tile::long_index_t>(a.num_tokens) *
a.stride_token* o_data_bytes // index_t moe_buf_bytes;
#endif
};
auto t1 = fused_moegemm_traits{t.prec_i,

View File

@@ -413,5 +413,6 @@ float fused_moesorting_mp(fused_moesorting_trait t,
int fused_moesorting_get_workspace_size(int tokens, int num_experts, int topk)
{
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
return ck_tile::moe_sorting_get_workspace_size(
tokens, num_experts, topk, 0 /*dispatch policy*/);
}

View File

@@ -399,7 +399,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
// if return zero, means no need workspace, can set moe_sorting_args.p_ws to nullptr
ck_tile::index_t workspace_size =
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch_policy*/);
ck_tile::DeviceMem moe_sorting_ws(workspace_size != 0 ? workspace_size : 0);
if(workspace_size != 0)
moe_sorting_ws.SetZero(); // note, clear here!!!!

View File

@@ -1,8 +1,149 @@
# Grouped CShuffle GEMM
# Grouped Gemm
This folder contains example for Grouped GEMM using ck_tile tile-programming implementation.
Grouped General Matrix Multiplication (Grouped GEMM) is a technique used in GPU computing and high-performance computing to batch together multiple independent GEMM operations (matrix multiplications) into a single kernel launch in order to improve performance and efficiency. This folder contains Grouped GEMM examples that use the ck_tile tile-programming implementation.
## build
## Quick Tour for New Users
The `Grouped GEMM` operators are versions of GEMM that run multiple GEMM operations within a single kernel call. Each GEMM operation performs a matrix multiplication. Unlike regular batched GEMM operations where both matrices must be of the same size and have the same configuration, Grouped GEMM operations can take matrices with different sizes and configurations, making them more flexible for diverse workloads.
Let's now break the example into the following parts: parsing arguments, preparing host and device buffers, preparing data, invoking GEMM, and building the example, while explaining each function.
### Parsing Arguments
The example takes three arguments: `group_count`, `repeat`, and `warmup`:
- `group_count`: the number of GEMM operations in the group,
- `repeat`: the number of times to repeat the kernel for benchmarking
- `warmup`: the number of iterations before the actual kernel run time measure.
```cpp
// Example
const int group_count = arg_parser.get_int("group_count");
const int repeat = arg_parser.get_int("repeat");
const int warmup = arg_parser.get_int("warmup");
```
In the next step, the input parameters `Ms`, `Ns`, `Ks`, as well as the corresponding `stride_As`, `stride_Bs`, and `stride_Cs` are either provided from the comand line or generated by default. Since one or more input data sets are expected for `A` and `B`, each parameter is stored in a `std::vector`. The size of the `vector` is defined by `group_count`.
```cpp
// Example
std::vector<ck_tile::index_t> Ms = arg_parser.get_int_vec("Ms");
std::vector<ck_tile::index_t> Ns = arg_parser.get_int_vec("Ns");
std::vector<ck_tile::index_t> Ks = arg_parser.get_int_vec("Ks");
std::vector<ck_tile::index_t> stride_As = arg_parser.get_int_vec("stride_As");
std::vector<ck_tile::index_t> stride_Bs = arg_parser.get_int_vec("stride_Bs");
std::vector<ck_tile::index_t> stride_Cs = arg_parser.get_int_vec("stride_Cs");
```
Where:
- `Ms` is the M dimension of each GEMM.
- `Ns` is the N dimension of each GEMM.
- `Ks` is the K dimension of each GEMM.
- `stride_As` is the stride values for matrix A.
- `stride_Bs` is the stride values for matrix B.
- `stride_Cs` is the stride values for matrix C.
### HostTensor and Device Memory Buffers (for CPU and GPU)
Each parameter `Ms`, `Ns`, `Ks`, `stride_As`, `stride_Bs` and `stride_Cs` contains values for more than one matrix, meaning different matrix sizes and strides can be used for different grouped GEMM computations.
The next step is to properly load the input values. For each input matrix, `A` and `B`, and for each output matrix, `C`, you need to create both `HostTensor` and `DeviceMemory`, where:
- `HostTensor` represents the matrix data on the host (CPU). It stores the data before they are transferred to the device for computation.
- `DeviceMemory` represents the matrix data on the device (GPU). This will store the data on the GPU for computation during the Grouped GEMM operation.
#### HostTensor Buffers (for CPU)
In the first step, create `HostTensor` for `A`, `B`, `C`. `HostTensor` allocates memory on the host (CPU) to store the matrices, initializing the memory with the appropriate dimensions and values to store the data. Below is an example code showing how to create HostTensors for those tensors:
```cpp
// Example
std::vector<ck_tile::HostTensor<ADataType>> a_m_k_tensors;
std::vector<ck_tile::HostTensor<BDataType>> b_k_n_tensors;
std::vector<ck_tile::HostTensor<CDataType>> c_m_n_tensors;
```
Where:
- `a_m_k_tensors` is the vector of `HostTensor` objects for matrix `A` (with dimensions `M × K`). Each tensor stores the data for single GEMM operation.
- `b_k_n_tensors` is the vector of `HostTensor` objects for matrix `B` (with dimensions `K × N`).
- `c_m_n_tensors` is the vector of `HostTensor` objects for matrix `C` (the output matrix with dimensions `M × N`).
The `std::vector` container is used for this purpose throughout. As mentioned above, the number of HostTensors is equal to `group_count`.
#### Device Memory Buffers (for GPU)
Now it's time to allocate memory on the device (GPU) and transfer the data from `HostTensor` to `DeviceMemory` for actual computation..
```cpp
// Example
std::vector<std::unique_ptr<ck_tile::DeviceMem>> a_m_k_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> b_k_n_dev_buf;
std::vector<std::unique_ptr<ck_tile::DeviceMem>> c_m_n_dev_buf;
```
Where:
- `a_m_k_dev_buf` is the buffer used for storing matrix A on the GPU.
- `b_k_n_dev_buf` is the buffer used for storing matrix B on the GPU.
- `c_m_n_dev_buf` is the buffer used for storing the result matrix C on the GPU.
## Prepare data
In the next step, the input tensors are populated. A pseudorandom number generator, an existing distribution (e.g., `FillUniformDistribution`), or user data can be used to populate the tensors. Descriptors also need to be create for each input tensor.
Use `get_default_stride` to get the strides for A, B, and C. `get_default_stride` is a template function that calculates the default stride for a 2D array based on whether it is row-major or column-major. Template parameter determines whether the storage order is row-major (true) or column-major (false). The function takes four params `row`, `col`, `stride` and `bool_constant<is_row_major>`. If the stride is explicitly provided (`stride != 0`), the stride is returned as-is. If the stride is not provided (`stride == 0`), the function computes the default stride. For the Row-major order (`is_row_major == true`), the stride is set to the number of columns (col). For the column-major order (`is_row_major == false`), the stride is set to the number of rows (row). This function is useful when working with dynamically allocated 2D arrays, where the user may not specify the stride explicitly. It ensures a natural default stride based on the chosen storage order.
```cpp
// Example, API
template <bool is_row_major>
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant<is_row_major>) {
// code
}
```
Where:
- `is_row_major` is a bool template parameter that determines whether the storage order is row-major (true) or column-major (false).
- `row` is the number of rows in the matrix.
- `col` is the number of columns in the matrix.
- `stride` is the current stride (the distance between consecutive elements in memory).
- `bool_constant<is_row_major>` is a tag type that helps in differentiating behavior at compile-time.
Next host descriptors for each of the input tensors, A, B, and C are created. Use the `f_host_tensor_descriptor` function defined below. This function takes four parameters, row, col, stride, and layout, and returns a HostTensorDescriptor based on the specified layout.
```cpp
// Example for tensor A
ck_tile::HostTensor<ADataType>(f_host_tensor_descriptor(M, K, stride_As[i], a_layout)))
```
After creating the host_tensors, create `deviceMem` for each tensor `A`, `B`, and `C`, and then transfer the data to the device. The `get_element_space_size_in_bytes()` function is used to get the buffer size in bytes. Use `ToDevice()` to transfer data from the host to the device. The data that was previously generated (`a_m_k_tensors[i].data()`) is passed as a parameter to `ToDevice()`.
The final step before running the GEMM operation is to retrieve the pointers to the buffers of `A`, `B`, and `C` stored on the device using `->GetDeviceBuffer()` and pack them into a shared container. For example: `gemm_descs.push_back({p_a, p_b, p_c, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]})`, where `gemm_descs` is `std::vector<grouped_gemm_kargs> gemm_descs` ([Code](https://github.com/ROCm/composable_kernel/blob/develop/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc#L221)). The container should include values such as:
```cpp
struct GroupedGemmHostArgs
{
const void* a_ptr;
const void* b_ptr;
void* c_ptr;
index_t M;
index_t N;
index_t K;
index_t stride_A;
index_t stride_B;
index_t stride_C;
};
```
The data prepared in this way can be passed to the `invoke_gemm` function. This is a templated function that also takes three template parameters: `ALayout`, `BLayout`, and `CLayout`:
```cpp
// Example, API
template <typename ALayout, typename BLayout, typename CLayout, bool Persistent>
float invoke_gemm(int n_warmup,
int n_repeat,
int group_count,
const std::vector<grouped_gemm_kargs>& args)
```
`invoke_gemm` returns the run time in milliseconds. The workspace memory required for computation is allocated. Workspace memory on the GPU refers to temporary memory buffers allocated when some operations are run. This extra space is needed to hold GEMM descriptions. The following structure can be used to allocate workspace:
```cpp
// Example
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(GetWorkspaceSize(args));
```
Finally the arguments are passed to group_gemm and the kernel is launched.
```cpp
// API
template <typename ALayout, typename BLayout, typename CLayout>
float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
```
All the necessary parameters are set, the tiling is computed, the GEMM pipeline and epilogue are prepared, and the GroupedGemmKernel is launched.
## Build
```
# in the root of ck_tile
mkdir build && cd build
@@ -16,10 +157,17 @@ This will result in an executable `build/bin/tile_example_grouped_gemm`
## example
```
args:
-a_layout Tensor A layout (default:R)
-b_layout Tensor B layout (default:R)
-c_layout Tensor C layout (default:R)
-v 0. No validation, 1. Validation on CPU
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-Ms M dimensions - (Default: empty).
-Ns N dimensions - (Default: empty).
-Ks K dimensions - (Default: empty).
-stride_As Tensor A strides - (Default: empty).
-stride_Bs Tensor B strides - (Default: empty).
-stride_Cs Tensor C strides - (Default: empty).
-a_layout A tensor data layout - (Default: Row).
-b_layout B tensor data layout - (Default: Col).
-c_layout C tensor data layout - (Default: Row).
-validate 0. No validation, 1. Validation on CPU. (Default: 1).
-warmup Number of iterations before benchmark the kernel. (Default: 10).
-repeat Number of iterations to benchmark the kernel. (Default: 100).
-group_count Group count. (Default: 16).
```

View File

@@ -116,9 +116,12 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
GemmPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,

View File

@@ -3,6 +3,4 @@ add_executable(tile_example_flatmm_basic EXCLUDE_FROM_ALL flatmm_basic.cpp)
set(EXAMPLE_FLATMM_COMPILE_OPTIONS)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -Wno-unused-variable -Wno-unused-parameter)
list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_16x16x32=1 -Wno-unused-local-typedef)
#list(APPEND EXAMPLE_FLATMM_COMPILE_OPTIONS -DUSING_MFMA_32x32x16=1 -Wno-unused-local-typedef)
target_compile_options(tile_example_flatmm_basic PRIVATE ${EXAMPLE_FLATMM_COMPILE_OPTIONS})

View File

@@ -13,48 +13,94 @@
#include "flatmm_basic.hpp"
#include "run_flatmm_example.inc"
template <typename ADataType,
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s)
typename DsLayout,
typename ELayout,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s)
{
using FlatmmConfig = FlatmmConfig<ADataType>;
using CodegenFlatmmShape = ck_tile::TileFlatmmShape<
using CodegenFlatmmShape = ck_tile::TileGemmShape<
ck_tile::sequence<FlatmmConfig::M_Tile, FlatmmConfig::N_Tile, FlatmmConfig::K_Tile>,
ck_tile::sequence<FlatmmConfig::M_Warp, FlatmmConfig::N_Warp, FlatmmConfig::K_Warp>,
ck_tile::sequence<FlatmmConfig::M_Warp_Tile,
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenFlatmmShape>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<CodegenFlatmmShape,
FlatmmConfig::TileParitionerGroupNum,
FlatmmConfig::TileParitionerM01>;
using CodegenGemmTraits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
CLayout>;
using CodegenPipelineProblem = ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits>;
const auto Run = [&](const auto memory_operation_) {
using Traits = ck_tile::TileGemmTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
ALayout,
BLayout,
ELayout,
FlatmmConfig::NumWaveGroups>;
using CodegenGemmTraits = ck_tile::TileGemmUniversalTraits<FlatmmConfig::kPadM,
FlatmmConfig::kPadN,
FlatmmConfig::kPadK,
FlatmmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
ELayout,
FlatmmConfig::TransposeC,
FlatmmConfig::UseStructuredSparsity,
persistent,
FlatmmConfig::NumWaveGroups,
true>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenFlatmmShape, Traits>;
using BaseGemmPipeline = ck_tile::BaseFlatmmPipelineAGmemBGmemCRegV1<GemmPipelineProblem>;
const ck_tile::index_t k_grain = args.k_batch * FlatmmConfig::K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * FlatmmConfig::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto scheduler = FlatmmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using CodegenPipelineProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenFlatmmShape,
CodegenGemmTraits,
scheduler,
has_hot_loop_v,
tail_number_v>;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
DsDatatype,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
DsLayout,
ELayout,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
@@ -64,11 +110,8 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
FlatmmConfig::N_Warp_Tile,
FlatmmConfig::K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation>>;
using CodegenFlatmmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
using CodegenFlatmmPipeline =
ck_tile::FlatmmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenFlatmmPolicy>;
memory_operation,
FlatmmConfig::NumWaveGroups>>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
@@ -86,14 +129,15 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName()
<< CodegenPipelineProblem::GetName() << " grid: {" << grids.x << ", "
<< grids.y << ", " << grids.z << "}"
std::cout << "Launching kernel with args:" << CodegenFlatmmShape::GetName() << "\n"
<< "Shape: " << CodegenFlatmmShape::GetName() << "\n"
<< "problem: " << CodegenPipelineProblem::GetName() << "\n"
<< "pipeline: " << CodegenFlatmmPipeline::GetName() << "\n"
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time{0};
if(s.flush_cache_)
{
std::cout << "Flushing cache..." << std::endl;
@@ -111,7 +155,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
kargs.a_ptr, kargs.b_shuffle_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
@@ -122,7 +166,7 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
// clear c mem
if(args.k_batch > 1)
hipGetErrorString(hipMemsetAsync(
args.c_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
};
ave_time = ck_tile::launch_kernel_preprocess(
s,
@@ -139,18 +183,28 @@ float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_con
}
return ave_time;
};
if(args.k_batch == 1)
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(args.k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
return ave_time;
}
template <template <typename PreType> typename FlatmmConfig>
int run_flatmm_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -163,24 +217,27 @@ int run_flatmm_example(int argc, char* argv[])
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(a_layout == "R" && b_layout == "C")
{
if(data_type == "fp16")
{
run_flatmm_example_with_layouts<ck_tile::half_t>(argc, argv, Row{}, Col{}, Row{});
run_flatmm_example_with_layouts<ck_tile::half_t, FlatmmConfig<ck_tile::half_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf16")
{
run_flatmm_example_with_layouts<ck_tile::bf16_t>(argc, argv, Row{}, Col{}, Row{});
run_flatmm_example_with_layouts<ck_tile::bf16_t, FlatmmConfig<ck_tile::bf16_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "fp8")
{
run_flatmm_example_with_layouts<ck_tile::fp8_t>(argc, argv, Row{}, Col{}, Row{});
run_flatmm_example_with_layouts<ck_tile::fp8_t, FlatmmConfig<ck_tile::fp8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else if(data_type == "bf8")
{
run_flatmm_example_with_layouts<ck_tile::bf8_t>(argc, argv, Row{}, Col{}, Row{});
run_flatmm_example_with_layouts<ck_tile::bf8_t, FlatmmConfig<ck_tile::bf8_t>>(
argc, argv, Row{}, Col{}, Row{});
}
else
{
@@ -196,9 +253,29 @@ int run_flatmm_example(int argc, char* argv[])
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
try
{
return !run_flatmm_example(argc, argv);
int warp_tile = arg_parser.get_int("warp_tile");
if(warp_tile == 0)
{
return !run_flatmm_example<FlatmmConfig16>(argc, argv);
}
else if(warp_tile == 1)
{
return !run_flatmm_example<FlatmmConfig32>(argc, argv);
}
else if(warp_tile == 2)
{
return !run_flatmm_example<FlatmmConfig16_950>(argc, argv);
}
else
{
return !run_flatmm_example<FlatmmConfig32_950>(argc, argv);
}
}
catch(const std::runtime_error& e)
{

View File

@@ -12,26 +12,81 @@
#include "ck_tile/ops/flatmm.hpp"
#include "ck_tile/ops/gemm.hpp"
#define CK_TILE_PIPELINE_COMPUTE 1
#define CK_TILE_PIPELINE_MEMORY 2
// GEMM config with 32x132 warp tile
template <typename DataType>
struct FlatmmConfig32
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
#ifndef CK_TILE_PIPELINE_DEFAULT
#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE
#endif
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrMem
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrMem
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Interwave
#elif(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_COMPUTE)
#define GEMM_PIPELINE ck_tile::GemmPipelineAgBgCrCompV3
#define UNIVERSAL_GEMM_PIPELINE ck_tile::BaseGemmPipelineAgBgCrCompV3
#define GEMM_PIPELINE_SCHEDULER ck_tile::GemmPipelineScheduler::Intrawave
#else
#error "unsupported CK_TILE_PIPELINE_DEFAULT value"
#endif
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 32;
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 2;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename DataType>
struct FlatmmConfig32_950 : public FlatmmConfig32<DataType>
{
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 16 : 64;
};
// GEMM config with 16x16 warp tile
template <typename DataType>
struct FlatmmConfig16
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(DataType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 64;
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 2;
static constexpr int TileParitionerGroupNum = 8;
static constexpr int TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename DataType>
struct FlatmmConfig16_950 : public FlatmmConfig16<DataType>
{
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(DataType) == 2 ? 32 : 128;
};
template <typename ADataType>
struct GemmBasicTypeConfig;
template <>
@@ -109,43 +164,6 @@ struct is_8bit_type
{
};
template <typename ADataType>
struct FlatmmConfig
{
#if defined(USING_MFMA_16x16x32)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 16 : 32;
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 64 : 16;
#elif defined(USING_MFMA_32x32x16)
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 256;
static constexpr ck_tile::index_t K_Tile = 128;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 8;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
static constexpr ck_tile::index_t N_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 32;
static constexpr ck_tile::index_t K_Warp_Tile = is_8bit_type<ADataType>::value ? 32 : 16;
#endif
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr int kBlockPerCu = 2;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -163,8 +181,11 @@ auto create_args(int argc, char* argv[])
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value");
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("warp_tile",
"0",
"0: 16x16, 1: 32x32, 2: 16x16x128 (950 only), 3: 32x32x64 (950 only)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
@@ -174,7 +195,8 @@ template <typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename FlatmmConfig,
typename ALayout,
typename BLayout,
typename CLayout>
float flatmm_calc(const ck_tile::FlatmmHostArgs& args, const ck_tile::stream_config& s);
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);

View File

@@ -18,6 +18,10 @@ constexpr const char* DataTypeToString()
{
return "bf8";
}
else if constexpr(std::is_same_v<T, ck_tile::bf16_t>)
{
return "bf16";
}
else
{
return "unknown";
@@ -69,13 +73,31 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ADataType,
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout>
typename DsLayout,
typename ELayout,
bool persistent,
typename CDEElementWise>
float flatmm_calc(const ck_tile::FlatmmHostArgs<>& args, const ck_tile::stream_config& s);
template <typename FlatmmConfig,
typename ADataType,
typename BDataType,
typename DsDatatype,
typename AccDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
ck_tile::DeviceMem& b_shuffle_dev_buf,
ck_tile::DeviceMem& c_dev_buf,
@@ -89,22 +111,32 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
int n_warmup,
int n_repeat)
{
ck_tile::FlatmmHostArgs args;
args.a_ptr = a_dev_buf.GetDeviceBuffer();
args.b_shuffle_ptr = b_shuffle_dev_buf.GetDeviceBuffer();
args.c_ptr = c_dev_buf.GetDeviceBuffer();
ck_tile::FlatmmHostArgs<> args = {a_dev_buf.GetDeviceBuffer(),
b_shuffle_dev_buf.GetDeviceBuffer(),
{},
c_dev_buf.GetDeviceBuffer(),
kbatch,
M,
N,
K,
stride_A,
stride_B,
{},
stride_C};
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
float ave_time =
flatmm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
float ave_time = flatmm_calc<FlatmmConfig,
ADataType,
BDataType,
DsDatatype,
AccDataType,
CDataType,
ALayout,
BLayout,
DsLayout,
CLayout,
false,
CDEElementWise>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte =
@@ -120,7 +152,11 @@ float invoke_flatmm(ck_tile::DeviceMem& a_dev_buf,
return ave_time;
}
template <typename PrecType, typename ALayout, typename BLayout, typename CLayout>
template <typename PrecType,
typename FlatmmConfig,
typename ALayout,
typename BLayout,
typename CLayout>
int run_flatmm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
@@ -131,11 +167,10 @@ int run_flatmm_example_with_layouts(int argc,
if(!result)
return -1;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
using FlatmmConfig = FlatmmConfig<ADataType>;
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
@@ -145,10 +180,11 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
// persistent not added
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
@@ -162,8 +198,26 @@ int run_flatmm_example_with_layouts(int argc,
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
// TODO: add different init types
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
if(init_method == 0)
{
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_origin_host);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<ADataType>{}(a_host);
ck_tile::FillMonotonicSeq<BDataType>{}(b_origin_host);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_host);
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_origin_host);
}
else
{
a_host.SetZero();
b_origin_host.SetZero();
}
ck_tile::DeviceMem a_dev_buf(a_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_dev_buf(c_rslt_host.get_element_space_size_in_bytes());
@@ -173,23 +227,30 @@ int run_flatmm_example_with_layouts(int argc,
// do pre-shuffle
ck_tile::HostTensor<BDataType> b_shuffle_host = shuffle_b<FlatmmConfig>(b_origin_host);
ck_tile::DeviceMem b_shuffle_dev_buf(b_shuffle_host.get_element_space_size_in_bytes());
b_shuffle_dev_buf.ToDevice(b_shuffle_host.data());
invoke_flatmm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
invoke_flatmm<FlatmmConfig,
ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ALayout,
BLayout,
ck_tile::tuple<>,
CLayout>(a_dev_buf,
b_shuffle_dev_buf,
c_dev_buf,
M,
N,
K,
stride_A,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_dev_buf.FromDevice(c_rslt_host.data());
bool pass = true;

View File

@@ -1 +1,6 @@
add_executable(tile_example_gemm_multi_d_fp16 EXCLUDE_FROM_ALL gemm_multi_d_fp16.cpp)
set(EXAMPLE_GEMM_COMPILE_OPTIONS)
if(CK_USE_OCP_FP8)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
endif()
target_compile_options(tile_example_gemm_multi_d_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,4 @@
add_executable(tile_example_grouped_conv_fwd EXCLUDE_FROM_ALL grouped_convolution_forward.cpp)
set(EXAMPLE_CONV_COMPILE_OPTIONS)
list(APPEND EXAMPLE_CONV_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
target_compile_options(tile_example_grouped_conv_fwd PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})

View File

@@ -0,0 +1,207 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <hip/hip_runtime.h>
#include <cstring>
#include <iostream>
#include <ostream>
#include <string>
#include <tuple>
#include "ck_tile/host.hpp"
#include "grouped_convolution_utils.hpp"
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout,
typename DsDataType = ck_tile::tuple<>,
typename DsLayout = ck_tile::tuple<>,
typename CDEElementWise = ck_tile::element_wise::PassThrough>
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s)
{
constexpr int kBlockPerCu = 1;
constexpr ck_tile::index_t M_Tile = 64;
constexpr ck_tile::index_t N_Tile = 64;
constexpr ck_tile::index_t K_Tile = 64;
constexpr ck_tile::index_t M_Warp = 2;
constexpr ck_tile::index_t N_Warp = 2;
constexpr ck_tile::index_t K_Warp = 1;
constexpr ck_tile::index_t M_Warp_Tile = 32;
constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 16;
constexpr ck_tile::index_t VectorSizeA = 8;
constexpr ck_tile::index_t VectorSizeB = 8;
constexpr ck_tile::index_t VectorSizeC = 8;
// Implicit GEMM Traits
using CodegenShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
constexpr auto ConvSpec = ck_tile::ConvolutionSpecialization::Default;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenShape>;
using GroupedConvTraitsType =
ck_tile::GroupedConvTraits<NDimSpatial, ConvSpec, InLayout, WeiLayout, DsLayout, OutLayout>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<InDataType,
WeiDataType,
AccDataType,
CodegenShape,
typename GroupedConvTraitsType::GroupedConvImplicitGemmTraits,
InDataType,
true,
VectorSizeA,
VectorSizeB>;
using CodegenPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
const auto Run = [&](const auto memory_operation_) {
constexpr auto memory_operation = memory_operation_.value;
using ConvEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<InDataType,
WeiDataType,
DsDataType,
AccDataType,
OutDataType,
typename GroupedConvTraitsType::ImplicitGemmDsLayout,
ck_tile::tensor_layout::gemm::RowMajor,
CDEElementWise,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
CodegenPipelineProblem::TransposeC,
memory_operation,
1,
true,
VectorSizeC>>;
using Kernel = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
TilePartitioner,
CodegenPipeline,
ConvEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args);
constexpr dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< '\n'
<< "Vector size A: " << CodegenPipeline::GetVectorSizeA()
<< ", Vector size B: " << CodegenPipeline::GetVectorSizeB()
<< ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
#include "run_grouped_convolution_example.inc"
template <typename InPrecType, typename WeiPrecType = InPrecType, typename OutPrecType = InPrecType>
int run_grouped_conv_fwd_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<1>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<2>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NHWGC{}, GKYXC{}, NHWGK{});
}
else if(in_layout == "NDHWGC" && wei_layout == "GKZYXC" && out_layout == "GKZYXC")
{
return run_grouped_conv_fwd_example_with_layouts<ck_tile::number<3>{},
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NDHWGC{}, GKZYXC{}, NDHWGK{});
}
else
{
throw std::runtime_error("Unsupported memory layout!");
}
}
int run_grouped_conv_fwd_example(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
std::string data_type = arg_parser.get_str("prec");
std::string in_layout = arg_parser.get_str("in_layout");
std::string wei_layout = arg_parser.get_str("weight_layout");
std::string out_layout = arg_parser.get_str("out_layout");
if(data_type == "fp16")
{
return run_grouped_conv_fwd_example_prec_type<ck_tile::half_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else if(data_type == "bf16")
{
return run_grouped_conv_fwd_example_prec_type<ck_tile::bf16_t>(
in_layout, wei_layout, out_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_example(argc, argv); }

View File

@@ -0,0 +1,108 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
ck_tile::index_t fill_spatial_dimensions(std::vector<ck_tile::index_t>& filter_spatial_lengths,
std::vector<ck_tile::index_t>& image_spatial_lengths,
std::vector<ck_tile::index_t>& strides,
std::vector<ck_tile::index_t>& dilations,
std::vector<ck_tile::index_t>& lpads,
std::vector<ck_tile::index_t>& rpads,
ck_tile::ArgParser& arg_parser)
{
constexpr ck_tile::index_t non_sp_dims = 3;
const ck_tile::index_t n_dim_sp = arg_parser.get_str("in_layout").size() - non_sp_dims;
if(!(n_dim_sp >= 1 && n_dim_sp <= 3))
{
throw std::runtime_error("Wrong layout!\n");
}
if(n_dim_sp == 3)
{
filter_spatial_lengths.push_back(arg_parser.get_int("z"));
image_spatial_lengths.push_back(arg_parser.get_int("d"));
strides.push_back(arg_parser.get_int("stride_d"));
dilations.push_back(arg_parser.get_int("dilation_d"));
lpads.push_back(arg_parser.get_int("lpad_d"));
rpads.push_back(arg_parser.get_int("rpad_d"));
}
if(n_dim_sp >= 2)
{
filter_spatial_lengths.push_back(arg_parser.get_int("y"));
image_spatial_lengths.push_back(arg_parser.get_int("h"));
strides.push_back(arg_parser.get_int("stride_h"));
dilations.push_back(arg_parser.get_int("dilation_h"));
lpads.push_back(arg_parser.get_int("lpad_h"));
rpads.push_back(arg_parser.get_int("rpad_h"));
}
filter_spatial_lengths.push_back(arg_parser.get_int("x"));
image_spatial_lengths.push_back(arg_parser.get_int("w"));
strides.push_back(arg_parser.get_int("stride_w"));
dilations.push_back(arg_parser.get_int("dilation_w"));
lpads.push_back(arg_parser.get_int("lpad_w"));
rpads.push_back(arg_parser.get_int("rpad_w"));
return n_dim_sp;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("g", "2", "group dimension")
.insert("n", "32", "n dimension")
.insert("k", "32", "k dimension")
.insert("c", "32", "c dimension")
.insert("d", "64", "d dimension")
.insert("h", "64", "h dimension")
.insert("w", "64", "w dimension")
.insert("z", "4", "z dimension")
.insert("y", "4", "y dimension")
.insert("x", "4", "x dimension")
.insert("stride_d", "1", "d stride")
.insert("stride_h", "1", "h stride")
.insert("stride_w", "1", "w stride")
.insert("dilation_d", "1", "d dilation")
.insert("dilation_h", "1", "h dilation")
.insert("dilation_w", "1", "w dilation")
.insert("lpad_d", "0", "left pad for d dimension")
.insert("lpad_h", "0", "left pad for h dimension")
.insert("lpad_w", "0", "left pad for w dimension")
.insert("rpad_d", "0", "right pad for d dimension")
.insert("rpad_h", "0", "right pad for h dimension")
.insert("rpad_w", "0", "right pad for w dimension")
.insert("in_layout", "NHWGC", "Input image layout - NHWGC by default")
.insert("weight_layout", "GKYXC", "Weight layout - GKYXC by default")
.insert("out_layout", "NHWGK", "Output image layout - NHWGK by default")
.insert("v", "1", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float grouped_conv_fwd(const ck_tile::GroupedConvHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -0,0 +1,206 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
template <typename InDataType, typename WeiDataType, typename AccDataType, typename OutDataType>
auto calculate_rtol_atol(const ck_tile::index_t GemmK,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(InDataType) < sizeof(WeiDataType), InDataType, WeiDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, OutDataType, AccDataType>(
ck_tile::integer_divide_ceil(GemmK, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, OutDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(GemmK, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<OutDataType, OutDataType, OutDataType>(kbatch);
const auto atol_split_k =
ck_tile::get_absolute_threshold<OutDataType, OutDataType, OutDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename AccDataType,
typename OutDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
float invoke_grouped_conv_fwd(ck_tile::GroupedConvHostArgs& args, int n_warmup, int n_repeat)
{
float ave_time = grouped_conv_fwd<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = args.GetFlops();
std::size_t num_byte = args.GetByte<InDataType, WeiDataType, OutDataType>();
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
<< std::endl;
return ave_time;
}
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType = InDataType,
typename OutDataType = InDataType,
typename InLayout,
typename WeiLayout,
typename OutLayout>
int run_grouped_conv_fwd_example_with_layouts(
int argc, char* argv[], const InLayout, const WeiLayout, const OutLayout)
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
using AccDataType = float;
std::vector<ck_tile::index_t> filter_spatial_lengths;
std::vector<ck_tile::index_t> image_spatial_lengths;
std::vector<ck_tile::index_t> strides;
std::vector<ck_tile::index_t> dilations;
std::vector<ck_tile::index_t> lpads;
std::vector<ck_tile::index_t> rpads;
const ck_tile::index_t num_dim_sp = fill_spatial_dimensions(filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads,
arg_parser);
ck_tile::conv::ConvParam conv_param{num_dim_sp,
arg_parser.get_int("g"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("c"),
filter_spatial_lengths,
image_spatial_lengths,
strides,
dilations,
lpads,
rpads};
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
const auto in_g_n_c_wis_desc =
ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(conv_param);
const auto wei_g_k_c_xs_desc =
ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(conv_param);
const auto out_g_n_k_wos_desc =
ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(conv_param);
ck_tile::HostTensor<InDataType> input(in_g_n_c_wis_desc);
ck_tile::HostTensor<WeiDataType> weight(wei_g_k_c_xs_desc);
ck_tile::HostTensor<OutDataType> output(out_g_n_k_wos_desc);
if(init_method == 0)
{
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{-5.f, 5.f}(weight);
}
else if(init_method == 1)
{
ck_tile::FillMonotonicSeq<InDataType>{}(input);
ck_tile::FillMonotonicSeq<WeiDataType>{}(weight);
}
else if(init_method == 2)
{
ck_tile::FillUniformDistribution<InDataType>{1.f, 1.f}(input);
ck_tile::FillUniformDistribution<WeiDataType>{1.f, 1.f}(weight);
}
else
{
input.SetZero();
weight.SetZero();
}
ck_tile::DeviceMem input_dev_buf(input.get_element_space_size_in_bytes());
ck_tile::DeviceMem weight_dev_buf(weight.get_element_space_size_in_bytes());
ck_tile::DeviceMem output_dev_buf(output.get_element_space_size_in_bytes());
input_dev_buf.ToDevice(input.data());
weight_dev_buf.ToDevice(weight.data());
output_dev_buf.SetZero();
ck_tile::GroupedConvHostArgs args(conv_param,
input_dev_buf.GetDeviceBuffer(),
weight_dev_buf.GetDeviceBuffer(),
{},
output_dev_buf.GetDeviceBuffer(),
kbatch);
std::cout << "Run Grouped Conv Fwd kernel" << std::endl;
std::cout << "input: " << input.mDesc << std::endl;
std::cout << "weight: " << weight.mDesc << std::endl;
std::cout << "output: " << output.mDesc << std::endl;
invoke_grouped_conv_fwd<NDimSpatial,
InDataType,
WeiDataType,
AccDataType,
OutDataType,
InLayout,
WeiLayout,
OutLayout>(args, n_warmup, n_repeat);
output_dev_buf.FromDevice(output.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
output_host_ref.SetZero();
ck_tile::reference_grouped_conv_fwd<NDimSpatial, InDataType, WeiDataType, OutDataType>(
input,
weight,
output_host_ref,
conv_param.conv_filter_strides_,
conv_param.conv_filter_dilations_,
conv_param.input_left_pads_,
conv_param.input_right_pads_);
const ck_tile::index_t GemmK = weight.get_element_size() / (conv_param.G_ * conv_param.K_);
const float max_accumulated_value =
*std::max_element(output_host_ref.mData.begin(), output_host_ref.mData.end());
const auto rtol_atol =
calculate_rtol_atol<InDataType, WeiDataType, AccDataType, OutDataType>(
GemmK, kbatch, max_accumulated_value);
pass = ck_tile::check_err(output,
output_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
throw std::runtime_error("Unsupported gpu verification !!!");
}
return pass;
}

View File

@@ -1,4 +0,0 @@
add_executable(test_copy_kernel EXCLUDE_FROM_ALL test_copy.cpp)
target_compile_options(test_copy_kernel PRIVATE
-mllvm -enable-noalias-to-md-conversion=0
)

View File

@@ -1,31 +0,0 @@
# Copy Kernel
This folder contains basic setup code designed to provide a platform for novice
CK_Tile kernel developers to test basic functionality with minimal additional
code compared to the functional code. Sample functional code for a simple
tile distribution for DRAM window and LDS window are provided and data is moved
from DRAM to registers, registers to LDS, LDS to registers and finally data
is moved to output DRAM window for a simple copy operation.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture
# (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# Make the copy kernel executable
make test_copy -j
```
This will result in an executable `build/bin/test_copy_kernel`
## example
```
args:
-m input matrix rows. (default 64)
-n input matrix cols. (default 8)
-id warp to use for computation. (default 0)
-v validation flag to check device results. (default 1)
-prec datatype precision to use. (default fp16)
-warmup no. of warmup iterations. (default 50)
-repeat no. of iterations for kernel execution time. (default 100)
```

View File

@@ -1,117 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include <cstring>
#include "test_copy.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "64", "m dimension")
.insert("n", "8", "n dimension")
.insert("id", "0", "warp to use")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "50", "cold iter")
.insert("repeat", "100", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
using XDataType = DataType;
using YDataType = DataType;
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t warp_id = arg_parser.get_int("id");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
ck_tile::HostTensor<XDataType> x_host({m, n});
ck_tile::HostTensor<YDataType> y_host_ref({m, n});
ck_tile::HostTensor<YDataType> y_host_dev({m, n});
// ck_tile::FillConstant<XDataType>{1.f}(x_host);
ck_tile::half_t value = 1;
for(int i = 0; i < m; i++)
{
value = 1;
for(int j = 0; j < n; j++)
{
x_host(i, j) = value++;
}
}
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
using BlockWaves = ck_tile::sequence<2, 1>;
using BlockTile = ck_tile::sequence<64, 8>;
using WaveTile = ck_tile::sequence<64, 8>;
using Vector = ck_tile::sequence<1, 4>;
ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{}));
std::cout << "grid size " << kGridSize << std::endl;
using Shape = ck_tile::TileCopyShape<BlockWaves, BlockTile, WaveTile, Vector>;
using Problem = ck_tile::TileCopyProblem<XDataType, Shape>;
using Kernel = ck_tile::TileCopy<Problem>;
constexpr ck_tile::index_t kBlockSize = 128;
constexpr ck_tile::index_t kBlockPerCu = 1;
std::cout << "block size " << kBlockSize << std::endl;
std::cout << "warp SIze " << ck_tile::get_warp_size() << std::endl;
std::cout << "warps per block _M " << Shape::WarpPerBlock_M << " " << Shape::WarpPerBlock_N
<< std::endl;
std::cout << "Block waves: " << BlockWaves::at(ck_tile::number<0>{}) << " "
<< BlockWaves::at(ck_tile::number<1>{}) << std::endl;
std::cout << " Wave Groups: " << Shape::WaveGroups << std::endl;
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
static_cast<XDataType*>(x_buf.GetDeviceBuffer()),
static_cast<YDataType*>(y_buf.GetDeviceBuffer()),
m,
n,
warp_id));
std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << gb_per_sec << " GB/s" << std::endl;
bool pass = true;
if(do_validation)
{
// reference
y_buf.FromDevice(y_host_dev.mData.data());
pass = ck_tile::check_err(y_host_dev, x_host);
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}

View File

@@ -1,178 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
namespace ck_tile {
template <typename BlockWaves, // num warps along seq<M, N>
typename BlockTile, // block size, seq<M, N>
typename WaveTile, // warp size, seq<M, N>
typename Vector> // contiguous elements(vector size) along seq<M, N>
struct TileCopyShape
{
// We split Workgroup waves into two specialized groups.
// One for reading data from global -> LDS, the other is doing reduction
static constexpr index_t WaveGroups = 2;
static constexpr index_t MWarps = BlockWaves::at(number<0>{});
static constexpr index_t NWarps = BlockWaves::at(number<0>{});
static constexpr index_t Block_M = BlockTile::at(number<0>{});
static constexpr index_t Block_N = BlockTile::at(number<1>{});
static constexpr index_t Warp_M = WaveTile::at(number<0>{});
static constexpr index_t Warp_N = WaveTile::at(number<1>{});
static constexpr index_t Vector_M = Vector::at(number<0>{});
static constexpr index_t Vector_N = Vector::at(number<1>{});
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
static constexpr index_t WarpPerBlock_M =
integer_divide_ceil(BlockWaves::at(number<0>{}), WaveGroups);
static constexpr index_t WarpPerBlock_N =
integer_divide_ceil(BlockWaves::at(number<1>{}), WaveGroups);
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{});
static constexpr index_t BlockSize = get_warp_size() * WaveNum;
static constexpr index_t WaveGroupSize = WaveNum / WaveGroups;
static_assert(WaveGroupSize == WarpPerBlock_M * WarpPerBlock_N, "Inconsisten wave group size!");
};
template <typename XDataType_, typename BlockShape_>
struct TileCopyProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
};
template <typename Problem_>
struct TileCopy
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using XDataType = typename Problem::XDataType;
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeDRAMDistribution()
{
using S = typename Problem::BlockShape;
constexpr index_t warp_size = get_warp_size();
constexpr index_t X0 = S::ThreadPerWarp_N; // threads needed along N dimension, fastest
// changing with given vector size.
constexpr index_t X1 =
S::Vector_N; // no. of elements along N dimensions to be read by each thread.
constexpr index_t Y0 =
S::WaveNum / S::WaveGroups; // no. of active warps working in this thread block.
constexpr index_t Y1 = warp_size / X0; // no. of threads in a warp needed along M dimension.
constexpr index_t Y2 =
S::Warp_M /
(Y1 *
Y0); // no. of iterations each warp needs to perform to cover the entire tile window.
constexpr auto outer_encoding =
tile_distribution_encoding<sequence<Y0>,
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
tuple<sequence<0>, sequence<1, 2>>,
tuple<sequence<0>, sequence<0, 0>>,
sequence<1, 2>,
sequence<1, 1>>{};
return make_static_tile_distribution(outer_encoding);
}
CK_TILE_DEVICE void
operator()(const XDataType* p_x, XDataType* p_y, index_t M, index_t N, index_t warp_id) const
{
using S = typename Problem::BlockShape;
// LDS Data.
__shared__ XDataType x_lds[number<S::Block_M>{} * number<S::Block_N>{}];
XDataType* __restrict__ p_x_lds = static_cast<XDataType*>(x_lds);
const auto x_lds_desc = make_naive_tensor_descriptor(
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}, number<S::Vector_N>{}),
make_tuple(number<S::Block_N>{}, number<S::Vector_N>{}, 1),
number<S::Vector_N>{},
number<1>{});
auto x_lds_block_desc = transform_tensor_descriptor(
x_lds_desc,
make_tuple(make_pass_through_transform(number<S::Block_M>{}),
make_merge_transform(
make_tuple(number<S::Block_N>{} / S::Vector_N, number<S::Vector_N>{}))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
auto x_lds_view = make_tensor_view<address_space_enum::lds>(p_x_lds, x_lds_block_desc);
auto x_block_lds_window =
make_tile_window(x_lds_view,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
MakeDRAMDistribution<Problem>());
auto x_block_lds_window_no_dist = make_tile_window(
x_lds_view, make_tuple(number<S::Block_M>{}, number<S::Block_N>{}), {0, 0});
// Input tensor
const auto iM = get_block_id() * S::Block_M;
const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
auto x_block_window =
make_tile_window(x_m_n,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{iM, 0},
MakeDRAMDistribution<Problem>());
// Output tensor
const auto y_m = make_naive_tensor_view<address_space_enum::global>(
p_y, make_tuple(M, N), make_tuple(N, 1), number<S::Vector_N>{}, number<1>{});
auto y_block_window =
make_tile_window(y_m, make_tuple(number<S::Block_M>{}, number<S::Block_N>{}), {iM, 0});
// Programming logic
index_t num_n_tile_iteration =
__builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N));
auto my_id = get_warp_id();
auto DramTileDist = x_block_window.get_tile_distribution();
using dram_reg_tile = decltype(make_static_distributed_tensor<XDataType>(DramTileDist));
for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
{
dram_reg_tile dram_tile;
if(my_id == warp_id)
{
// load from DRAM to registers
load_tile(dram_tile, x_block_window);
// store in lds
store_tile(x_block_lds_window_no_dist, dram_tile);
// read from lds to registers
load_tile(dram_tile, x_block_lds_window);
// store from registers to DRAM
store_tile(y_block_window, dram_tile);
}
__syncthreads();
move_tile_window(x_block_window, {0, S::Block_N});
move_tile_window(y_block_window, {0, S::Block_N});
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,9 @@
set(TARGET_NAME tile_example_transpose)
add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL transpose_example.cpp transpose_api.cpp)
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(tile_example_transpose PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS})

View File

@@ -0,0 +1,27 @@
# Batched Transpose
This folder contains example for transpose load for architecture gfx950. This transpose load has some constraints in input tile distribution.
## build
```
# in the root of ck_tile
mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
# Make the transpose executable
make tile_example_transpose -j
```
This will result in an executable `build/bin/tile_example_transpose`
## example
```
args:
-N input batch size (default:2)
-C input channel size. (default:64)
-H input height size. (default:1)
-W input width size. (default:64)
-v whether do CPU validation or not (default: 1)
-layout_in input tensor data layout - NCHW by default
-layout_out output tensor data layout - NHWC by default
-seed seed to be used, -1 means random every time (default:-1)
-k_name t to 1 will print kernel name (default:0)
```

View File

@@ -0,0 +1,120 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct BatchedTransposeHostArgs
{
const void* p_input;
void* p_output;
index_t batch;
index_t height;
index_t width;
// index_t dim_blocks;
index_t dim_stride;
index_t dim_block_h;
index_t dim_block_w;
};
template <typename Pipeline_>
struct BatchedTransposeKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using Type = typename Problem::DataType;
struct BatchedTransposeKargs
{
const void* p_input;
void* p_output;
index_t batch;
index_t height;
index_t width;
index_t dim_stride;
};
using Kargs = BatchedTransposeKargs;
using Hargs = BatchedTransposeHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
size_t grid_size_x = h.dim_block_w;
size_t grid_size_y = h.dim_block_h;
size_t grid_size_z = h.batch;
return dim3(grid_size_x, grid_size_y, grid_size_z);
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_input = h.p_input;
k.p_output = h.p_output;
k.batch = h.batch;
k.height = h.height;
k.width = h.width;
k.dim_stride = h.dim_stride;
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
__shared__ char smem[Pipeline::GetSmemSize()];
static constexpr ck_tile::index_t kMPerBlock = Problem::kSecondSizePerBlock;
static constexpr ck_tile::index_t kNPerBlock = Problem::kLeadSizePerBlock;
const auto iDim = blockIdx.z;
const auto x_m_n = [&]() {
const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
make_tuple(kargs.height, kargs.width),
make_tuple(kargs.width, 1),
number<Pipeline::GetVectorSize()>{},
number<1>{});
return pad_tensor_view(x_dram_naive,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
sequence<false, false>{});
}();
const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock);
const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock);
const auto y_n_m = [&]() {
const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
make_tuple(kargs.width, kargs.height),
make_tuple(kargs.height, 1),
number<Pipeline::GetVectorSize()>{},
number<1>{});
return pad_tensor_view(y_dram_naive,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
sequence<false, false>{});
}();
auto x_block_window = make_tile_window(
x_m_n,
make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
{static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});
auto y_block_window = make_tile_window(
y_n_m,
make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
{static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});
Pipeline{}(x_block_window, y_block_window, smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,149 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "transpose_policy.hpp"
namespace ck_tile {
template <typename Layout_, index_t kRow, index_t kCol>
struct TransposeTraits
{
static constexpr index_t kLeadDim = kCol;
static constexpr index_t kSecondDim = kRow;
};
template <index_t kRow, index_t kCol>
struct TransposeTraits<tensor_layout::gemm::ColumnMajor, kRow, kCol>
{
static constexpr index_t kLeadDim = kRow;
static constexpr index_t kSecondDim = kCol;
};
// supports 2D transpose which will store to lds, then use ds_read_b*_tr_b* instruction to get the
// transposed data; Layout in TransposePipelineProblem is the original layout of the data in the
// global memory
template <typename DataType_,
typename Layout_,
index_t kBlockSize_,
index_t kRowWarps_, // how many warps in row direction
index_t kColWarps_, // how many warps in col direction
index_t kRowPerBlock_, // row number per block
index_t kColPerBlock_, // col number per block
index_t kRowPerXdl_, // row number per xdl ops
index_t kColPerXdl_> // col number per xdl ops
struct TransposePipelineProblem
{
static_assert(kRowWarps_ * kColWarps_ * get_warp_size() == kBlockSize_,
"the block size is not correct!");
using DataType = remove_cvref_t<DataType_>;
using Layout = remove_cvref_t<Layout_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kLeadNumWarps =
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kLeadDim;
static constexpr index_t kSecondNumWarps =
TransposeTraits<Layout, kRowWarps_, kColWarps_>::kSecondDim;
static constexpr index_t kLeadSizePerBlock =
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kLeadDim;
static constexpr index_t kSecondSizePerBlock =
TransposeTraits<Layout, kRowPerBlock_, kColPerBlock_>::kSecondDim;
static constexpr index_t kLeadSizePerXdl =
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kLeadDim;
static constexpr index_t kSecondSizePerXdl =
TransposeTraits<Layout, kRowPerXdl_, kColPerXdl_>::kSecondDim;
static constexpr index_t kQuadrantLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
static constexpr index_t kQuadrantSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
static_assert(kLeadSizePerBlock % kLeadNumWarps == 0,
"block dim should be divided by warp dim!");
static_assert(kSecondSizePerBlock % kSecondNumWarps == 0,
"block dim should be divided by warp dim!");
// how many rows/cols implemented in one warp
static constexpr index_t kLeadSizePerWarp = kLeadSizePerBlock / kLeadNumWarps;
static constexpr index_t kSecondSizePerWarp = kSecondSizePerBlock / kSecondNumWarps;
static_assert(kLeadSizePerWarp % kLeadSizePerXdl == 0,
"warp dim should be divided by xdl dim!");
static_assert(kSecondSizePerWarp % kSecondSizePerXdl == 0,
"warp dim should be divided by xdl dim!");
// warp rows/cols is divided into xdl.
static constexpr index_t kLeadXdlNumPerWarp = kLeadSizePerWarp / kLeadSizePerXdl;
static constexpr index_t kSecondXdlNumPerWarp = kSecondSizePerWarp / kSecondSizePerXdl;
static_assert(kLeadSizePerXdl % kQuadrantLeadDim == 0,
"xdl dim should be divided by quad dim!");
static_assert(kSecondSizePerXdl % kQuadrantSecondDim == 0,
"xdl dim should be divided by quad dim!");
// xdl rows/cols is divided into quadrants.
static constexpr index_t kQuadNumPerLeadDim = kLeadSizePerXdl / kQuadrantLeadDim;
static constexpr index_t kQuadNumPerSecondDim = kSecondSizePerXdl / kQuadrantSecondDim;
static constexpr index_t kIterationsInSecondDim =
kQuadNumPerLeadDim * kQuadNumPerSecondDim * 16 / get_warp_size();
};
template <typename Problem_, typename Policy_ = TransposePolicy>
struct BlockTranspose
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using DataType = remove_cvref_t<typename Problem::DataType>;
using Layout = remove_cvref_t<typename Problem::Layout>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kLeadSizePerBlock = Problem::kLeadSizePerBlock;
static constexpr index_t kSecondSizePerBlock = Problem::kSecondSizePerBlock;
static constexpr index_t GetVectorSize() { return Policy::template GetVectorSize<Problem>(); }
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename InputTileWindow, typename OutputTileWindow>
CK_TILE_DEVICE void operator()(const InputTileWindow& input_window,
OutputTileWindow& output_window,
void* __restrict__ p_smem)
{
auto input_tile_window =
make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
auto output_tile_window =
make_tile_window(output_window, Policy::template MakeOutputDistribution<Problem>());
DataType* p_lds_ptr = static_cast<DataType*>(p_smem);
constexpr auto in_lds_block_desc = Policy::template MakeLdsStoreBlockDescriptor<Problem>();
auto input_lds_block =
make_tensor_view<address_space_enum::lds>(p_lds_ptr, in_lds_block_desc);
constexpr auto out_lds_block_desc = Policy::template MakeLdsLoadBlockDescriptor<Problem>();
auto output_lds_block =
make_tensor_view<address_space_enum::lds>(p_lds_ptr, out_lds_block_desc);
auto copy_to_lds_window =
make_tile_window(input_lds_block,
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
{0, 0});
auto load_from_lds_window =
make_tile_window(output_lds_block,
make_tuple(number<kSecondSizePerBlock>{}, number<kLeadSizePerBlock>{}),
{0, 0},
Policy::template MakeLdsLoadTileDistribution<Problem>());
auto x = load_tile(input_tile_window);
store_tile(copy_to_lds_window, x);
block_sync_lds();
auto y = load_tile_transpose(load_from_lds_window);
store_tile(output_tile_window, y);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "transpose_example.hpp"
#include <iostream>
template <typename ts_type,
ck_tile::index_t block_x,
ck_tile::index_t block_y,
ck_tile::index_t warp_x,
ck_tile::index_t warp_y>
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
{
uint32_t dim_block_h = (a.height + block_y - 1) / block_y;
uint32_t dim_block_w = (a.width + block_x - 1) / block_x;
uint32_t dim_stride = a.height * a.width;
a.dim_stride = dim_stride;
a.dim_block_h = dim_block_h;
a.dim_block_w = dim_block_w;
using ts_problem = ck_tile::TransposePipelineProblem<ts_type,
ck_tile::tensor_layout::gemm::RowMajor,
64,
1,
1,
block_y,
block_x,
warp_y,
warp_x>;
using ts_pipeline = ck_tile::BlockTranspose<ts_problem>;
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
return ave_time;
}
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s)
{
if(t.type == "fp16")
{
return batched_transpose_dispatch<ck_tile::fp16_t, 16, 32, 16, 32>(a, s);
}
else if(t.type == "fp8")
{
return batched_transpose_dispatch<ck_tile::fp8_t, 16, 64, 16, 64>(a, s);
}
return -1;
}

View File

@@ -0,0 +1,257 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "transpose_example.hpp"
#if 0
template <typename T>
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
{
auto len = x.get_lengths();
assert(len.size() == 4);
std::cout << "[";
for(size_t i = 0; i < len[0]; i++)
{
std::cout << i << ": [";
for(size_t j = 0; j < len[1]; j++)
{
std::cout << j << ": [";
for(size_t k = 0; k < len[2]; k++)
{
std::cout << k << ": [";
for(size_t v = 0; v < len[3]; v++)
{
if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
{
auto m =
ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
std::cout << m;
if(v != len[3] - 1)
std::cout << ",";
}
else
{
std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
}
}
std::cout << "]" << std::endl;
}
std::cout << "]" << std::endl;
}
std::cout << std::endl;
}
std::cout << "--------------------" << std::endl;
}
#endif
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "whether do CPU validation or not")
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("N", "2", "input batch size. ")
.insert("C", "64", "input channel size.")
.insert("H", "1", "input height size.")
.insert("W", "64", "input width size. ")
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "t to 1 will print kernel name");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename Type>
bool run_batched_transpose(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string prec = args.get_str("pr");
int N = args.get_int("N");
int C = args.get_int("C");
int H = args.get_int("H");
int W = args.get_int("W");
std::string layout_in = args.get_str("layout_in");
std::string layout_out = args.get_str("layout_out");
int seed = args.get_int("seed");
int dim_in[4], dim_out[4];
int stride_dim_in[4], stride_dim_out[4];
bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
assert(nchw2nhwc != nhwc2nchw);
(void)nhwc2nchw;
dim_in[0] = N;
dim_in[1] = nchw2nhwc ? C : H;
dim_in[2] = nchw2nhwc ? H : W;
dim_in[3] = nchw2nhwc ? W : C;
dim_out[0] = N;
dim_out[1] = nchw2nhwc ? H : C;
dim_out[2] = nchw2nhwc ? W : H;
dim_out[3] = nchw2nhwc ? C : W;
stride_dim_in[0] = C * H * W;
stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
stride_dim_in[2] = nchw2nhwc ? W : C;
stride_dim_in[3] = 1;
stride_dim_out[0] = C * H * W;
stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
stride_dim_out[2] = nchw2nhwc ? C : W;
stride_dim_out[3] = 1;
if(seed < 0)
{
seed = std::time(nullptr);
}
ck_tile::HostTensor<Type> x_host(
{dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
{stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
ck_tile::HostTensor<Type> y_host(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
auto trait = batched_transpose_trait{prec, layout_in};
uint32_t height = nchw2nhwc ? C : H * W;
uint32_t width = nchw2nhwc ? H * W : C;
batched_transpose_kargs karg = [&]() {
batched_transpose_kargs a_;
a_.p_input = x_dev.GetDeviceBuffer();
a_.p_output = y_dev.GetDeviceBuffer();
a_.batch = N;
a_.height = height;
a_.width = width;
return a_;
}();
ck_tile::stream_config sc{nullptr, true};
auto ms = batched_transpose(trait, karg, sc);
std::size_t num_operations = N * C * H * (W - 1);
std::size_t num_bytes = N * C * H * W * sizeof(Type);
float ave_time = ms * 1E-3;
float gb_per_sec = num_bytes / ms * 1.E-6;
float tflops = static_cast<float>(num_operations) / ms * 1.E-6;
std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
<< ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
<< " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops"
<< gb_per_sec << " GB/s, " << std::endl;
printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
prec.c_str(),
N,
C,
H,
W,
layout_in.c_str(),
ms);
if(ms < 0)
printf("not supported\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
y_dev.FromDevice(y_host.data());
bool rtn = true;
if(validate)
{
// this host buffer will not copy to GPU, so no need use stride
ck_tile::HostTensor<Type> y_ref(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);
auto [rtol, atol] = get_elimit<Type>("");
rtn &= ck_tile::check_err(
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
}
printf("valid:%s\n", rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string prec = args.get_str("pr");
bool r = true;
if(prec.compare("fp16") == 0)
{
r &= run_batched_transpose<ck_tile::fp16_t>(args);
}
else if(prec.compare("fp8") == 0)
{
r &= run_batched_transpose<ck_tile::fp8_t>(args);
}
else
{
std::cerr << "Unsupported data type: " << prec << std::endl;
}
return r ? 0 : -1;
}

View File

@@ -0,0 +1,27 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "batched_transpose_kernel.hpp"
#include "block_transpose.hpp"
#include "transpose_policy.hpp"
#include <vector>
#include <string>
#pragma once
struct batched_transpose_trait
{
std::string type;
std::string layout;
};
struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
{
};
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s);

View File

@@ -0,0 +1,151 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
struct TransposePolicy
{
static constexpr auto TileAccessPattern = tile_distribution_pattern::thread_raked;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSize()
{
return 16 / sizeof(typename Problem::DataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return integer_least_multiple(
sizeof(typename Problem::DataType) *
MakeLdsStoreBlockDescriptor<Problem>().get_element_space_size(),
16);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t LeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t SecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t VecLoadSize = 16 / sizeof(typename Problem::DataType);
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
SecondDimPerBlock,
LeadDimPerBlock,
VecLoadSize,
TileAccessPattern>;
return TileEncodingPattern::Make2DStaticTileDistribution();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
constexpr auto input_dstr = MakeLdsLoadTileDistribution<Problem>();
using OutTileDstrEncode =
typename OutputTileDistributionTraits<remove_cvref_t<decltype(input_dstr)>,
typename Problem::DataType>::OutDstrEncode;
constexpr auto block_dstr = make_static_tile_distribution(OutTileDstrEncode{});
return block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreBlockDescriptor()
{
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t kVectorSize = 16 / sizeof(typename Problem::DataType);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kSecondDimPerBlock>{},
number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}),
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
number<kVectorSize>{},
number<1>{});
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadBlockDescriptor()
{
constexpr index_t kLeadDimPerBlock = Problem::kLeadSizePerBlock;
constexpr index_t kSecondDimPerBlock = Problem::kSecondSizePerBlock;
constexpr index_t kVectorSize = 8 / sizeof(typename Problem::DataType);
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kSecondDimPerBlock>{},
number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}),
make_tuple(number<kLeadDimPerBlock>{}, number<kVectorSize>{}, number<1>{}),
number<kVectorSize>{},
number<1>{});
constexpr auto lds_block_desc = transform_tensor_descriptor(
lds_block_desc_0,
make_tuple(make_pass_through_transform(number<kSecondDimPerBlock>{}),
make_merge_transform(make_tuple(number<kLeadDimPerBlock / kVectorSize>{},
number<kVectorSize>{}))),
make_tuple(sequence<0>{}, sequence<1, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadTileDistribution()
{
using DataType = typename Problem::DataType;
// Extract base dimensions from the traits
constexpr index_t kBaseLeadDim = LaneGroupTransposeTraits<DataType>::kleadDim;
constexpr index_t kBaseSecondDim = LaneGroupTransposeTraits<DataType>::ksecondDim;
// Calculate block-level dimensions
constexpr index_t kLead = Problem::kLeadSizePerXdl;
constexpr index_t kSecond = Problem::kSecondSizePerXdl;
constexpr index_t kLeadIterPerWarp = Problem::kLeadXdlNumPerWarp;
constexpr index_t kSecondIterPerWarp = Problem::kSecondXdlNumPerWarp;
constexpr index_t kLeadNumWarps = Problem::kLeadNumWarps;
constexpr index_t kSecondNumWarps = Problem::kSecondNumWarps;
// Calculate repetitions of base pattern
constexpr index_t kLeadRepetitions = kLead / kBaseLeadDim;
constexpr index_t kSecondRepetitions = kSecond / kBaseSecondDim;
constexpr index_t kSecondDimIterations = Problem::kIterationsInSecondDim;
constexpr index_t kSecondDimStrSub = kSecondRepetitions / kSecondDimIterations;
constexpr auto xdllevel_dstr_encoding = make_transposed_distr_encode<DataType,
kSecondDimStrSub,
kSecondDimIterations,
kLeadRepetitions,
1>();
constexpr auto input_tile_encode =
InputTileDistributionEncoding<decltype(xdllevel_dstr_encoding),
kLeadIterPerWarp,
kSecondIterPerWarp,
kLeadNumWarps,
kSecondNumWarps>();
constexpr auto block_dstr = make_static_tile_distribution(input_tile_encode);
return block_dstr;
}
};
} // namespace ck_tile

View File

@@ -19,5 +19,6 @@ add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
add_subdirectory(18_flatmm)
add_subdirectory(19_gemm_multi_d)
add_subdirectory(20_grouped_convolution)
add_subdirectory(35_batched_transpose)
add_subdirectory(36_copy)
add_subdirectory(37_transpose)