Optimized GEMMs for MX FP4/8 (#2294)

Adds V3 GEMM pipeline for MXFP4 and MXFP8
Adds V3 GEMM pipeline for MXFP4 with preshuffling
Adds MXFP4 GEMM tests (#2275)
Adds MXFP4 GEMM examples
Adds MXFP4 GEMMs to ckProfiler

Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com>

[ROCm/composable_kernel commit: 00247e3c29]
This commit is contained in:
Andriy Roshchenko
2025-06-05 13:54:15 -06:00
committed by GitHub
parent 4fba4073d3
commit ab0540c5db
83 changed files with 8193 additions and 2165 deletions

View File

@@ -13,7 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW).
* Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW).
* Added support for Stream-K version of mixed fp8/bf16 GEMM
* Added GEMM pipeline for microscaling (MX) data types
* Added GEMM pipeline for microscaling (MX) FP8/FP4 data types
* Added support for FP16 2:4 structured sparsity to universal GEMM.
* Added support for Split K for grouped convolution backward data.
* Added logit soft-capping support for fMHA forward kernels.

View File

@@ -39,6 +39,12 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8_streamk_v3)
add_example_executable(example_gemm_xdl_bf16_v3 gemm_xdl_bf16_v3.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_bf16_v3)
set(GEMM_OPTIONS)
list(APPEND GEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-16")
example_compile_options(example_gemm_xdl_fp8_v3 PRIVATE ${GEMM_OPTIONS})
example_compile_options(example_gemm_xdl_bf16_v3 PRIVATE ${GEMM_OPTIONS})
list(APPEND gpu_list gfx942 gfx950)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
@@ -34,7 +34,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Xdl_C
//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<1, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>;
< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, S<8, 1, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>;
// clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm<ADataType,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
@@ -71,9 +71,9 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
256, // BlockSize
256, // MPerBlock
128, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
64, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
4, // MXdlPerWave
@@ -84,14 +84,14 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmMultiD
2, // ABlockTransferSrcVectorDim
8, // ABlockTransferSrcScalarPerVector
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
0, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
8, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
0, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdlSplitKCShu
//######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//######| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 16, 16, 16, 1, 1, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, S<1, 2, 8, 8>, S<0, 2, 1, 3>, 3, 2, true, 1, 1, S<1, 32, 1, 4>, 4>;
< ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, GemmDefault, 2, 128, 32, 16, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>;
// clang-format on
#else

View File

@@ -6,6 +6,39 @@ add_example_dependencies(example_gemm_mx example_gemm_mx_fp8)
add_example_executable(example_gemm_mx_bf8 gemm_mx_bf8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_bf8)
add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8)
#add_example_executable(example_gemm_mx_fp8_bf8 gemm_mx_fp8_bf8.cpp)
# add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_bf8) TODO: Fix RRR
add_example_executable(example_gemm_mx_fp4 gemm_mx_fp4.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4)
add_example_executable(example_gemm_mx_fp4_bpreshuffle gemm_mx_fp4_bpreshuffle.cpp)
add_example_dependencies(example_gemm_mx example_gemm_mx_fp4_bpreshuffle)
#add_example_executable(example_moe_gemm1_xdl_mx_fp4 moe_gemm1_xdl_mx_fp4.cpp)
# add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4) TODO: Fix
#add_example_executable(example_moe_gemm1_xdl_mx_fp4_bns moe_gemm1_xdl_mx_fp4_bns.cpp)
#add_example_dependencies(example_gemm_mx example_moe_gemm1_xdl_mx_fp4_bns)
#add_example_executable(example_moe_gemm2_xdl_mx_fp4 moe_gemm2_xdl_mx_fp4.cpp)
# add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4) TODO: Fix
#add_example_executable(example_moe_gemm2_xdl_mx_fp4_bns moe_gemm2_xdl_mx_fp4_bns.cpp)
#add_example_dependencies(example_gemm_mx example_moe_gemm2_xdl_mx_fp4_bns)
set(FP4_MXGEMM_OPTIONS)
list(APPEND FP4_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --amdgpu-use-amdgpu-trackers=1")
#list(APPEND FP4_MXGEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker -ftemplate-backtrace-limit=0)
example_compile_options(example_gemm_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_fp4_bpreshuffle PRIVATE ${FP4_MXGEMM_OPTIONS})
# example_compile_options(example_moe_gemm1_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
# example_compile_options(example_moe_gemm2_xdl_mx_fp4 PRIVATE ${FP4_MXGEMM_OPTIONS})
# example_compile_options(example_moe_gemm1_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
# example_compile_options(example_moe_gemm2_xdl_mx_fp4_bns PRIVATE ${FP4_MXGEMM_OPTIONS})
set(FP8_MXGEMM_OPTIONS)
list(APPEND FP8_MXGEMM_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1 -mllvm --slp-threshold=-32")
#list(APPEND FP8_MXGEMM_OPTIONS -v --save-temps -Wno-gnu-line-marker -ftemplate-backtrace-limit=0)
example_compile_options(example_gemm_mx_fp8 PRIVATE ${FP8_MXGEMM_OPTIONS})
example_compile_options(example_gemm_mx_bf8 PRIVATE ${FP8_MXGEMM_OPTIONS})

View File

@@ -21,11 +21,11 @@ using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 128;
constexpr ck::index_t KPerBlock = 256;
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
@@ -45,32 +45,32 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
ScaleBlockSize, // ScaleBlockSize: Scaling block size
128, // BlockSize: Thread block size
128, // MPerBlock
16, // NPerBlock
32, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
1, // NXdlPerWave
S<8, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
2, // NXdlPerWave
S<16, 8, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<8, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
true, // ABlockLdsExtraM
S<16, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 16, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
2, // CShuffleBlockTransferScalarPerVector_NPerBlock
4, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
@@ -83,6 +83,7 @@ int main(int argc, char* argv[])
ADataType,
BDataType,
XDataType,
XDataType,
CDataType,
ALayout,
BLayout,

View File

@@ -23,8 +23,9 @@
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using MFMA = ck::tensor_layout::gemm::MFMA;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
@@ -36,6 +37,8 @@ struct ExecutionConfig final
int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values)
bool time_kernel = false; // (0=no, 1=yes)
int verbosity = 0; // (0=no info, 1=verbose info)
int warm_up = 10;
int repeat = 10;
};
struct ProblemSizeSplitK final
@@ -86,6 +89,8 @@ bool parse_cmd_args(int argc,
if(argc >= 12)
{
problem_size.KBatch = std::stoi(argv[11]);
}
// guard the optional warm_up/repeat arguments separately to avoid reading
// argv out of bounds when only KBatch is given
if(argc >= 14)
{
config.warm_up = std::stoi(argv[12]);
config.repeat = std::stoi(argv[13]);
}
}
else
@@ -103,10 +108,90 @@ bool parse_cmd_args(int argc,
return true;
}
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The four 16x128 building blocks are packed into one 32x256 block for F4,
// i.e. eight 16x16x128 mfma become one 32x32x256 mfma.
// Unfold the MN32 x K(256/32) scale buffer in the order
// XdlKThread(4) -> XdlMNThread(16) -> KXdlPack(2) -> MNXdlPack(2),
// then MNRepeat -> KRepeat.
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f,
// 2-k)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
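// Worked example (illustrative only, using the constants above): with MN = 32
// and K = 8 scales per row, XdlKThread = 64 / 16 = 4 and K0 = 8 / 2 / 4 = 1.
// The scale at (n = 17, k = 5) decomposes as n0 = 0, n1 = 1, n2 = 1 and
// k0 = 0, k1 = 1, k2 = 1, so it lands at
// outputIndex = 1*(2*2*16) + 1*(2*2) + 1*2 + 1 = 71 in the shuffled buffer.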
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
int tempk;
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
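// Worked example (illustrative only): with NXdl = 16 and KPack = 16,
// KLane = 4; for K = 256 f4 values, K_pk = 128 packed bytes and K0 = 2.
// The packed element at (n = 20, k = 70) decomposes as n0 = 1, n1 = 4,
// k0 = 1, k1 = 0, k2 = 6, landing at
// outputIndex = 1*2048 + 1*1024 + 0*256 + 4*16 + 6 = 3142.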
template <typename DeviceOpInstance,
typename ADataType,
typename BDataType,
typename XDataType,
typename XPackedDataType,
typename CDataType,
typename ALayout,
typename BLayout,
@@ -119,6 +204,8 @@ template <typename DeviceOpInstance,
ck::index_t ScaleBlockSize>
bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config)
{
constexpr bool BPreShuffle = ck::is_same_v<BLayout, MFMA>;
using BRefLayout = ck::conditional_t<BPreShuffle, Col, BLayout>;
auto M = problem_size.M;
auto N = problem_size.N;
@@ -131,28 +218,19 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto f_host_tensor_descriptor =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return HostTensorDescriptor({row, col}, {stride, 1});
}
else
{
return HostTensorDescriptor({row, col}, {1, stride});
}
};
auto f_get_default_stride =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// give a chance if stride is -1, return a default packed stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<ck::index_t>(col);
}
else
{
return static_cast<ck::index_t>(row);
}
}
else
return static_cast<ck::index_t>(stride);
@@ -172,16 +250,30 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
using AScaleLayout = Row;
using BScaleLayout = Col;
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Padded_M = (M + ScaleBlockSize - 1) / ScaleBlockSize * ScaleBlockSize;
auto Scale_Stride_AM =
f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
auto b_k_n =
std::make_shared<Tensor<BDataType>>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{}));
auto b_input = b_k_n;
if constexpr(BPreShuffle)
b_input = std::make_shared<Tensor<BDataType>>(
f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size
// scales for A and B
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
Tensor<XDataType> b_k_n_scale(f_host_tensor_descriptor(
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
Tensor<XDataType> b_k_n_scale(
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
// shuffled scales for A and B
Tensor<XDataType> a_shuffled_scale(f_host_tensor_descriptor(
Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
Tensor<XDataType> b_shuffled_scale(
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification
@@ -192,18 +284,31 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
{
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n->mDesc << std::endl;
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl;
}
auto a_data_element = [](float x) {
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
return ck::type_convert<ADataType>(ck::float2_t(x));
else
return ck::type_convert<ADataType>(x);
};
auto b_data_element = [](float x) {
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
return ck::type_convert<BDataType>(ck::float2_t(x));
else
return ck::type_convert<BDataType>(x);
};
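// Note: for packed fp4, ck::float2_t(x) broadcasts x into both components, so
// e.g. a_data_element(1.0f) yields one f4x2_pk_t holding the pair (1.0, 1.0);
// non-packed types fall through to a plain scalar type_convert.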
switch(config.init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(0.5f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{b_data_element(2.0f)}(*b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(b_k_n_scale);
if(config.verbosity > 0)
{
std::cout << "Init A = {1}" << std::endl;
@@ -216,29 +321,20 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
if constexpr(ck::is_same_v<XDataType, ck::e8m0_bexp_t>)
{
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
}
else
{
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
}
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 6}); // Z[-5,5]
b_k_n->GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 6}); // Z[-5,5]
static_assert(ck::is_same_v<XDataType, ck::e8m0_bexp_t>);
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{120, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n->GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
break;
@@ -249,20 +345,33 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
}
}
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
a_shuffled_scale.mData.data(),
Scale_Padded_M,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<BRefLayout, Col>>(
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
if constexpr(BPreShuffle)
{
int NPerXdl = 16; // Fixed 16
preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl);
}
if(config.verbosity > 0)
std::cout << "Device memory allocation..." << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize());
if(config.verbosity > 0)
std::cout << "Upload data to device..." << std::endl;
a_device_buf.ToDevice(a_m_k.mData.data());
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data());
b_device_buf.ToDevice(b_input->mData.data());
b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data());
if(config.verbosity > 0)
std::cout << "Done." << std::endl;
@@ -275,9 +384,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto invoker = device_op.MakeInvoker();
auto argument =
device_op.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<XDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<XDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
@@ -299,13 +408,26 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
"not consistent with the supported device_gemm arguments.");
}
std::size_t total_size =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() +
a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() +
a_shuffled_scale.GetElementSpaceSizeInBytes() +
b_shuffled_scale.GetElementSpaceSizeInBytes();
const auto total_cnt = ck::math::integer_divide_ceil(512 * 1024 * 1024, total_size);
const int rotating_count = std::max(1, std::min(config.repeat, static_cast<int>(total_cnt)));
if(config.verbosity > 0)
{
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
}
float ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
float ave_time = invoker.Run(argument,
StreamConfig{nullptr,
config.time_kernel,
config.verbosity,
config.warm_up,
config.repeat,
rotating_count > 1,
rotating_count});
bool res_verified = true;
if(config.do_verification > 0)
@@ -332,7 +454,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k_scale,
b_k_n,
*b_k_n,
b_k_n_scale,
c_m_n_host_result,
PassThrough{},
@@ -347,20 +469,21 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
std::cout << "Comparing results..." << std::endl;
}
if(config.init_method == 0)
{
auto expected = static_cast<float>(K);
auto computed = type_convert<float>(c_m_n_device_result(1, 12));
// if(config.init_method == 0)
// {
// auto expected = static_cast<float>(K);
// auto computed = type_convert<float>(c_m_n_device_result(1, 12));
res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
<< std::endl;
}
// res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
// std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
// << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
// << std::endl;
// }
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
c_m_n_host_result,
"Error: Incorrect results!");
res_verified =
res_verified &&
ck::utils::check_err(
c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", 5e-1, 5e-1);
if(config.verbosity > 0 && res_verified)
std::cout << "Verification Successful!" << std::endl;
@@ -377,13 +500,14 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c
// partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
std::size_t num_btype =
sizeof(ADataType) * M * K / ck::packed_size_v<ADataType> +
sizeof(BDataType) * K * N / ck::packed_size_v<BDataType> + sizeof(CDataType) * M * N +
sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
float gb_per_sec = static_cast<float>(num_btype) / 1e6f / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s, " << device_op.GetTypeString() << std::endl;
@@ -396,6 +520,7 @@ template <typename DeviceOpInstance,
typename ADataType,
typename BDataType,
typename XDataType,
typename XPackedDataType,
typename CDataType,
typename ALayout,
typename BLayout,
@@ -416,6 +541,7 @@ bool run_mx_gemm_example(int argc, char* argv[])
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,

View File

@@ -0,0 +1,105 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::f4x2_pk_t;
using BDataType = ck::f4x2_pk_t;
// using ADataType = ck::f4_t;
// using BDataType = ck::f4_t;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
// AB DataType: f4x2_pk_t
// Mathematically, all numbers are represented as f4x2.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
256, // MPerBlock
256, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}
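// Rough LDS footprint of the tile above (a sketch, ignoring the extra rows
// from ABlockLdsExtraM/BBlockLdsExtraN): the A tile is MPerBlock * KPerBlock =
// 256 * 128 one-byte f4x2 elements = 32 KiB, the B tile likewise 32 KiB,
// i.e. about 64 KiB of LDS per workgroup.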

View File

@@ -0,0 +1,105 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "gemm_mx_common.hpp"
using ADataType = ck::f4x2_pk_t;
using BDataType = ck::f4x2_pk_t;
// using ADataType = ck::f4_t;
// using BDataType = ck::f4_t;
using XDataType = ck::e8m0_bexp_t;
using XPackedDataType = int32_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = MFMA;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t DataPackedSize = 2; // Packed representation of data
constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
// AB DataType: f4x2_pk_t
// Mathematically, all numbers are represented as f4x2.
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XPackedDataType, // AScaleDataType
BDataType, // BDataType
XPackedDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
512, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
true, // ABlockLdsExtraM
S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
int main(int argc, char* argv[])
{
return run_mx_gemm_example<DeviceOpInstance,
ADataType,
BDataType,
XDataType,
XPackedDataType,
CDataType,
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
ScaleBlockSize>(argc, argv)
? 0
: -1;
}
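// Note: BLayout = MFMA means B is consumed directly in the preshuffled MFMA
// layout; in this example flow the host prepares it with preShuffleBuffer()
// from gemm_mx_common.hpp (NPerXdl fixed to 16 there) before uploading.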

View File

@@ -25,7 +25,7 @@ constexpr ck::index_t KPerBlock = 256;
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
@@ -49,26 +49,26 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
16, // MPerXDL
16, // NPerXDL
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
true, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
true, // BBlockLdsExtraN
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
@@ -83,6 +83,7 @@ int main(int argc, char* argv[])
ADataType,
BDataType,
XDataType,
XDataType,
CDataType,
ALayout,
BLayout,

View File

@@ -24,7 +24,7 @@ constexpr ck::index_t ScaleBlockSize = 32; // scaling block size
constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
@@ -43,30 +43,30 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle
GemmSpec, // GemmSpec
ScaleBlockSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
256, // MPerBlock
256, // NPerBlock
128, // KPerBlock
128, // MPerBlock
128, // NPerBlock
256, // KPerBlock
16, // AK1
8, // BK1
16, // MPerXDL
16, // NPerXDL
8, // MXdlPerWave
8, // NXdlPerWave
S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
4, // MXdlPerWave
4, // NXdlPerWave
S<16, 16, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<16, 16, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<32, 8, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
1, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
8, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleMXdlPerWavePerShuffle
2, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
@@ -82,6 +82,7 @@ int main(int argc, char* argv[])
ADataType,
BDataType,
XDataType,
XDataType,
CDataType,
ALayout,
BLayout,

View File

@@ -222,12 +222,18 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
set(result 0)
endif()
#message("add_example returns ${result}")
set(result ${result} PARENT_SCOPE)
endfunction(add_example_executable_no_testing EXAMPLE_NAME)
function(example_compile_options EXAMPLE_NAME)
if(TARGET ${EXAMPLE_NAME})
target_compile_options(${EXAMPLE_NAME} ${ARGN})
endif()
endfunction(example_compile_options)
# add all example subdir
file(GLOB dir_list LIST_DIRECTORIES true *)
FOREACH(subdir ${dir_list})

View File

@@ -35,6 +35,9 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
using ComputeTypeB = BDataType;
using AccType = float; // for now only support V_MFMA_SCALE_F32
static constexpr index_t APackedSize = packed_size_v<ComputeTypeA>;
static constexpr index_t BPackedSize = packed_size_v<ComputeTypeB>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -48,17 +51,24 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
// static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
static constexpr index_t B_K1 =
BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB, TransposeC, true>{};
static constexpr auto xdlops_gemm = XdlopsGemm<ComputeTypeA,
MPerXDL,
NPerXDL,
KPack * APackedSize,
ComputeTypeB,
TransposeC,
true>{};
static constexpr index_t AMmaKStride = KPack;
static constexpr index_t BMmaKStride = KPack;
//> store rows/cols into thread registers in chunks of 16
//> e.g. [k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47]
static constexpr index_t KThreadChunk = 16;
static constexpr index_t KThreadChunk = 16 / sizeof(ComputeTypeA);
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
@@ -67,22 +77,29 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
using HotLoopInstList =
ck::BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
A_K1,
B_K1,
A_K1,
B_K1,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops>;
// Hardcoded to 2 for a better 8-bit access pattern
static constexpr index_t MXdlPack = 2;
static constexpr index_t NXdlPack = 2;
static constexpr index_t KXdlPack = 2;
using HotLoopInstList = ck::BlockwiseGemmXdlops_pipeline_hotloop_inst< //
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
A_K1,
B_K1,
A_K1,
B_K1,
MRepeat,
NRepeat,
MPerXDL,
NPerXDL,
xdlops_gemm.KPerXdlops,
(packed_size_v<ComputeTypeA> > 1 || packed_size_v<ComputeTypeB> > 1)>;
static_assert(KPerThread % KPack == 0,
"Wrong KPack setting; try increasing KPerThread or decreasing KPack");
@@ -116,7 +133,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
return make_tuple(0, waveId_m, 0, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
@@ -127,7 +144,7 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
return make_tuple(0, waveId_n, 0, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]);
}
template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
@@ -142,24 +159,27 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
make_tuple(
make_unmerge_transform(make_tuple(MRepeat / MXdlPack, MWaves, MXdlPack, MPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
make_tuple(Sequence<0, 1, 2, 3>{}));
constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
make_tuple(
make_unmerge_transform(make_tuple(NRepeat / NXdlPack, NWaves, NXdlPack, NPerXDL))),
make_tuple(Sequence<0>{}),
make_tuple(Sequence<0, 1, 2>{}));
make_tuple(Sequence<0, 1, 2, 3>{}));
// We pack 2 mfma in M/N direction, so we need to divide by 2
const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
make_tuple(m0 / MXdlPack, waveId_m, m0 % MXdlPack, blk_idx[I0]))[I0];
const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
make_tuple(n0 / NXdlPack, waveId_n, n0 % NXdlPack, blk_idx[I1]))[I0];
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
using Tuple5 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base.
@@ -179,13 +199,12 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
BlockwiseGemmXdlops_mx_pipeline_base(Tuple5 a_origin = CalculateAThreadOriginDataIndex(),
Tuple5 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
"wrong! Desc should be known at compile-time");
static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
@@ -221,6 +240,28 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
}
// XDL output supporting C_xdl = A_xdl * B_xdl, packed mfma
__host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
return make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
I1,
I1,
Number<MXdlPack>{},
Number<NXdlPack>{},
M0,
M1,
M2,
N));
}
__host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
@@ -262,6 +303,23 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
}
// XDL output supporting C_xdl = A_xdl * B_xdl, packed mfma
__host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3()
{
constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
Number<MWaves>{},
Number<NWaves>{},
Number<MXdlPack>{},
Number<NXdlPack>{},
Number<MPerXDL>{},
Number<NPerXDL>{}));
return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
c_block_desc_m0_n0_m1_n1_m2_n2);
}
__host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
{
constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
@@ -314,45 +372,47 @@ struct BlockwiseGemmXdlops_mx_pipeline_base
c_grid_desc_g_m0_n0_m1_n1_m2_n2);
}
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
__host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; }
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_m3_k;
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_n3_k;
protected:
// M1, N1 as double buffer index
// Read buffer + Compute buffer
// A[M0, M1, M2, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<MRepeat / MXdlPack>{}, I1, Number<MXdlPack>{}, Number<KRepeat>{}, Number<KPack>{}));
// B[N0, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
make_tuple(
Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<NRepeat / NXdlPack>{}, I1, Number<NXdlPack>{}, Number<KRepeat>{}, Number<KPack>{}));
// C[M, N, NumRegXdlops]
static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
static constexpr auto c_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat / MXdlPack>{},
Number<NRepeat / NXdlPack>{},
Number<MXdlPack>{},
Number<NXdlPack>{},
xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_block_desc_m0_m1_m2_m3_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
Sequence<1, 1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3, 4>,
4,
A_K1,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeTypeB,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_block_desc_n0_n1_n2_n3_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3>,
3,
Sequence<1, 1, 1, 1, KThreadChunk>,
Sequence<0, 1, 2, 3, 4>,
4,
B_K1,
B_K1>;
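// Packing arithmetic implied above (a sketch, assuming 1-byte packed inputs
// such as f4x2_pk_t with packed_size 2 and KPack = 16): xdlops_gemm is
// instantiated with KPack * APackedSize = 32 f4 values per K-slice, while
// KThreadChunk = 16 / sizeof(ComputeTypeA) = 16 elements, i.e. one 16-byte
// register chunk per copy; MXdlPack/NXdlPack/KXdlPack = 2 pair neighbouring
// XDL tiles so packed scales and outputs are handled two at a time.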

View File

@@ -145,7 +145,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3<BlockGemmPipelineSch
using Base::MWaves;
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, BDataType>{};
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType>{};
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;

View File

@@ -270,10 +270,10 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1<BlockGemmPipelineScheduler::I
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
__builtin_amdgcn_sched_barrier(0);
// // Local prefill A1
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
// // Global prefetch A2
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

View File

@@ -58,11 +58,21 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
static constexpr index_t KRepeat = KPerThread / KPack;
static constexpr index_t KPerInnerLoop = KPack;
static constexpr index_t KGroup =
((MPerXDL == 16 && NPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
(MPerXDL == 32 && NPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
? 2
: 1;
static constexpr index_t KGroup = []() {
if constexpr(is_same_v<remove_cvref_t<ComputeDataType>, f8_t>)
// On gfx950, some mfma instructions require 32 f8 elements as input,
// split into 2 groups of 16 f8 elements.
// The 2 groups are not contiguous in the B preshuffled layout, and we
// do not want them to be contiguous there, because a memory
// instruction can only read 16 f8 elements at a time.
return ((MPerXDL == 16 && NPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) ||
(MPerXDL == 32 && NPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
? 2
: 1;
else
return 1;
}();
static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
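// Illustration of the KGroup comment above: for f8 with a 16x16x128 scaled
// mfma, each lane supplies 16 * 128 / 64 = 32 f8 values; KGroup = 2 splits
// them into two 16-byte loads, which is why the preshuffled B layout
// deliberately keeps the two halves apart.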

View File

@@ -0,0 +1,68 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp"
namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
// must be used for compute
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmMXBPreshufflePipeline_Selector()
{
// Hardware MX GEMM pipeline
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmXdlops_pipeline_v3_mx_bprehuffle<BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;
}
}
} // namespace ck
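// Note: only the v3 pipeline is handled here; any other BlkGemmPipelineVer
// falls through the error branch without a return value, so the selector
// effectively yields void and instantiation fails at the call site.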

View File

@@ -4,38 +4,9 @@
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp"
namespace ck {
/**
* @brief Define matrix data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_data_type()
{
return is_same_v<T, f8_ocp_t> || is_same_v<T, bf8_ocp_t> || is_same_v<T, f6_t> ||
is_same_v<T, bf6_t> || is_same_v<T, f4_t>;
}
/**
* @brief Define scale data types that have hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_scale_type()
{
return is_same_v<T, e8m0_bexp_t>;
}
/**
* @brief Combination of data types that have hardware support for MX GEMMs
*/
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
static constexpr bool scale_mfma_hw_support()
{
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
}
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
@@ -89,6 +60,30 @@ constexpr auto BlockGemmMXPipeline_Selector()
NRepeat,
KPack>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return BlockwiseGemmXdlops_pipeline_v3_mx<BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "MX GEMM Pipeline configuration is not available" << std::endl;

View File

@@ -205,7 +205,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -136,15 +136,21 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::a_block_desc_m0_m1_m2_m3_k;
using Base::b_block_desc_n0_n1_n2_n3_k;
using Base::AMmaKStride;
using Base::APackedSize;
using Base::BMmaKStride;
using Base::BPackedSize;
using Base::KThreadChunk;
using Base::KXdlPack;
using Base::MXdlPack;
using Base::NXdlPack;
using AccType = typename Base::AccType;
using Tuple4 = typename Base::Tuple4;
using Tuple5 = typename Base::Tuple5;
using ComputeTypeA = typename Base::ComputeTypeA;
using ComputeTypeB = typename Base::ComputeTypeB;
@@ -156,11 +162,26 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block
//> How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
static constexpr auto AScalesPerXdlopsRun =
(APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
static constexpr auto BScalesPerXdlopsRun =
(BPackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;
//> How many scales a thread must read to accommodate one call to xdlops_gemm.Run()
static constexpr auto ScalesPerXdlopsRunPerThread =
ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
static constexpr auto ScalesPerXdlopsRunPerThreadA =
AScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
static constexpr auto ScalesPerXdlopsRunPerThreadB =
BScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
using mx_scale_t = e8m0_bexp_t;
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
"A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
"B scale pack data type too large!");
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
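// Example (matching the fp4 examples above, where the scale types are
// int32_t): scale_pack_size_a = sizeof(int32_t) / sizeof(e8m0_bexp_t) = 4,
// and with KXdlPack * MXdlPack = 2 * 2 = 4 scales per thread per packed
// xdlops step, a_scale_thread_vec_size = 4 / 4 = 1 packed element.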
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
@@ -232,76 +253,58 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
b_scale_thread_desc.GetElementSpaceSize());
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, k0, I0),
a_scale_thread_buf);
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, I1, 0));
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, k0, I0),
b_scale_thread_buf);
b_scale_thread_buf(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(0, I1, 0));
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
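The window moves above are pure bookkeeping: each inner iteration steps the K coordinate forward, each outer iteration advances one wave row while rewinding K, and the final compound step returns the window to the origin advanced by one K block. A small host-side sketch that replays the index arithmetic (sizes are illustrative):

#include <cassert>

int main()
{
    constexpr int MRepeat = 4, MXdlPack = 2, KRepeat = 4, KXdlPack = 2, MWaves = 2;

    int row = 0, kcol = 0; // current window origin (scale rows, K position)
    for(int m0 = 0; m0 < MRepeat / MXdlPack; ++m0)
    {
        for(int k0 = 0; k0 < KRepeat / KXdlPack; ++k0)
            kcol += 1;                  // MoveSrcSliceWindow(0, I1, 0)
        row += MWaves;                  // step to the next wave row...
        kcol -= KRepeat / KXdlPack;     // ...and rewind K
    }
    // restore row id and advance to the next set of scales
    row  -= MWaves * MRepeat / MXdlPack;
    kcol += KRepeat / KXdlPack;

    assert(row == 0 && kcol == KRepeat / KXdlPack); // back at origin, one K block ahead
}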
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
__builtin_amdgcn_s_waitcnt(3952); // 0xF70: vmcnt(0); EXP_CNT and LGKM_CNT (LDS, GDS, Constant, Message) left at no-wait
block_sync_lds();
// Initialize C
c_thread_buf.Clear();
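The magic number 3952 passed to __builtin_amdgcn_s_waitcnt above is the raw s_waitcnt immediate. Decoding it with the gfx9-style field layout (stated here as an assumption) gives vmcnt = 0 with expcnt and lgkmcnt left at their no-wait maxima, i.e. the instruction waits only for outstanding vector-memory loads:

#include <cstdio>

int main()
{
    // Assumed gfx9 s_waitcnt immediate layout:
    //   vmcnt[3:0] = bits 3:0, expcnt = bits 6:4, lgkmcnt = bits 11:8
    constexpr unsigned imm     = 3952;             // 0xF70
    constexpr unsigned vmcnt   = imm & 0xF;        // 0  -> wait for all vmem ops
    constexpr unsigned expcnt  = (imm >> 4) & 0x7; // 7  -> no wait
    constexpr unsigned lgkmcnt = (imm >> 8) & 0xF; // 15 -> no wait (LDS/GDS/const/msg)

    std::printf("vmcnt=%u expcnt=%u lgkmcnt=%u\n", vmcnt, expcnt, lgkmcnt);
}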
@@ -314,13 +317,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
// wait previous blockwise copy to finish
// k indexes mapping to threads for 32x32x64:
// t0 : |0 --> 15 32 --> 47 | 64 --> 79 96 --> 111 | etc.
@@ -335,160 +333,184 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// k = 0 k = 1
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(
a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(
b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
// load for next k loop
block_sync_lds();
a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf);
b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
static_assert(0 < ScalesPerXdlopsRunPerThread,
static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
0 < ScalesPerXdlopsRunPerThreadB,
"Must have at least one scale per Xdlops per Thread.");
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread>
a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread>
b_scale_thread_vec;
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(
Number<c_offset>{}));
});
});
});
});
});
});
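Functionally, each scaled-MFMA call above performs an FP32 accumulation in which every ScaleBlockSize-element slice of A and B (32 elements for the MX formats) is first multiplied by its e8m0 block scale. A scalar reference of that contraction, offered as a model of the math rather than the exact hardware semantics:

#include <cmath>
#include <cstdio>

// C = sum_k (a[k] * 2^(sa[k/SB]-127)) * (b[k] * 2^(sb[k/SB]-127)), where SB is
// the scale block size; a/b stand for already-dequantized FP4/FP8 values.
float scaled_dot(const float* a, const unsigned char* sa,
                 const float* b, const unsigned char* sb, int K, int SB)
{
    float acc = 0.0f;
    for(int k = 0; k < K; ++k)
        acc += a[k] * std::exp2f(float(sa[k / SB] - 127)) *
               b[k] * std::exp2f(float(sb[k / SB] - 127));
    return acc;
}

int main()
{
    float a[64], b[64];
    unsigned char sa[2] = {127, 128}, sb[2] = {127, 126}; // 2^0, 2^1 and 2^0, 2^-1
    for(int k = 0; k < 64; ++k) { a[k] = 0.5f; b[k] = 2.0f; }
    std::printf("%f\n", scaled_dot(a, sa, b, sb, 64, 32)); // 64.0
}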
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s));
auto a_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_copy.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_copy,
make_tuple(I0, I0),
a_scale_thread_buf_copy);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, k0, I0),
a_scale_thread_buf);
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_copy[Number<0>{}];
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(0, I1, 0));
});
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
a_scale_grid_desc, make_multi_index(MWaves, -KRepeat / KXdlPack, 0));
});
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize));
a_scale_grid_desc,
make_multi_index(-MWaves * MRepeat / MXdlPack, KRepeat / KXdlPack, 0));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
constexpr auto b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s));
auto b_scale_thread_buf_copy =
make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc_copy.GetElementSpaceSize());
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc_copy,
make_tuple(I0, I0),
b_scale_thread_buf_copy);
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, k0, I0),
b_scale_thread_buf);
b_scale_thread_buf(Number<b_scale_offset>{}) =
b_scale_thread_buf_copy[Number<0>{}];
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(0, I1, 0));
});
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize));
b_scale_grid_desc, make_multi_index(NWaves, -KRepeat / KXdlPack, 0));
});
// restore col id and advance to the next set of scales
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
b_scale_grid_desc,
make_multi_index(-NWaves * NRepeat / NXdlPack, KRepeat / KXdlPack, 0));
__builtin_amdgcn_s_waitcnt(3952); // 0xF70: vmcnt(0); EXP_CNT and LGKM_CNT left at no-wait
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
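The loop body interleaves next-iteration work with current-iteration math: LDS reads and MFMAs for tile i run alongside global loads and scale prefetches for tile i+1, and the LDS prefill for i+1 lands behind the waitcnt/barrier pair at the end. Stripped of the tensor machinery, the schedule looks roughly like this (a sketch of the structure only, not the exact CK call sequence; all names are illustrative stubs):

#include <cstdio>

static void read_lds_tile_to_vgprs(int i) { std::printf("lds->vgpr  tile %d\n", i); }
static void issue_global_loads(int i)     { std::printf("global rd  tile %d\n", i); }
static void run_scaled_mfmas(int i)       { std::printf("mfma       tile %d\n", i); }
static void prefetch_scales(int i)        { std::printf("scales     tile %d\n", i); }
static void wait_and_sync()               { std::printf("waitcnt + barrier\n"); }
static void prefill_lds(int i)            { std::printf("vgpr->lds  tile %d\n", i); }

int main()
{
    const int num_loop = 4;
    int i = 0;
    do
    {
        read_lds_tile_to_vgprs(i);  // a/b_thread_copy_.Run, chunked per KThreadChunk
        issue_global_loads(i + 1);  // a/b_blockwise_copy.Run + MoveSrcSliceWindow
        run_scaled_mfmas(i);        // xdlops_gemm.Run over m0/n0/k0 and the xdl packs
        prefetch_scales(i + 1);     // a/b_scale_thread_copy.Run
        wait_and_sync();            // s_waitcnt + block_sync_lds
        prefill_lds(i + 1);         // a/b_blockwise_copy.RunWrite
        i += 1;
    } while(i < num_loop - 1);
    run_scaled_mfmas(i);            // tail: consume the last resident tile
}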
@@ -497,87 +519,128 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto k_step =
k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops);
k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
// read block data in chunks to assemble correct thread
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
static_for<0, xdlops_gemm.K1PerXdlops / APackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto a_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
I0,
Number<a_k_step_chunk>{}),
a_block_buf,
a_thread_desc_,
make_tuple(Number<m0 / MXdlPack>{},
I0,
Number<m0 % MXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
a_thread_buf);
});
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
// read block data in chunks to assemble correct thread
static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
// read block data in chunks to assemble correct thread vectors
static_for<0, xdlops_gemm.K1PerXdlops / BPackedSize / KThreadChunk, 1>{}(
[&](auto chunk) {
constexpr auto b_k_step_chunk =
k_step +
chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
I0,
Number<b_k_step_chunk>{}),
b_block_buf,
b_thread_desc_,
make_tuple(Number<n0 / NXdlPack>{},
I0,
Number<n0 % NXdlPack>{},
k,
Number<chunk * KThreadChunk>{}),
b_thread_buf);
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
vector_type<AScaleDataType, ScalesPerXdlopsRunPerThread> a_scale_thread_vec;
vector_type<BScaleDataType, ScalesPerXdlopsRunPerThread> b_scale_thread_vec;
static_assert(0 < ScalesPerXdlopsRunPerThreadA &&
0 < ScalesPerXdlopsRunPerThreadB,
"Must have at least one scale per Xdlops per Thread.");
// Pack b_scale_thread_buf into b_scale_thread_vec
static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) {
vector_type<AScaleDataType, a_scale_thread_vec_size> a_scale_thread_vec;
vector_type<BScaleDataType, b_scale_thread_vec_size> b_scale_thread_vec;
// Pack scale_thread_buf into scale_thread_vec
static_for<0, a_scale_thread_vec_size, 1>{}([&](auto s) {
a_scale_thread_vec.template AsType<AScaleDataType>()(s) =
a_scale_thread_buf[Number<a_scale_offset + s>{}];
});
static_for<0, b_scale_thread_vec_size, 1>{}([&](auto s) {
b_scale_thread_vec.template AsType<BScaleDataType>()(s) =
b_scale_thread_buf[Number<b_scale_offset + s>{}];
});
using mfma_input_type_a =
typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
using mfma_input_type_b =
typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
static_for<0, KXdlPack, 1>{}([&](auto ikxdl) {
static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
constexpr auto kxdl = ikxdl + k0 * KXdlPack;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
vector_type<ComputeTypeA, KPack> a_thread_vec;
vector_type<ComputeTypeB, KPack> b_thread_vec;
// MFMA accumulation
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec.template AsType<AScaleDataType>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec.template AsType<BScaleDataType>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, imxdl, kxdl, ik))>{}];
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, inxdl, kxdl, ik))>{}];
});
using mfma_input_type_a = typename vector_type< //
ComputeTypeA,
xdlops_gemm.K1PerXdlops / APackedSize>::type;
using mfma_input_type_b = typename vector_type< //
ComputeTypeB,
xdlops_gemm.K1PerXdlops / BPackedSize>::type;
using mfma_scale_input_type_a = typename vector_type< //
AScaleDataType,
a_scale_thread_vec_size>::type;
using mfma_scale_input_type_b = typename vector_type< //
BScaleDataType,
b_scale_thread_vec_size>::type;
constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, imxdl, inxdl, 0));
// MFMA accumulation
xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
ikxdl * NXdlPack + inxdl>(
a_thread_vec.template AsType<mfma_input_type_a>(),
a_scale_thread_vec
.template AsType<mfma_scale_input_type_a>(),
b_thread_vec.template AsType<mfma_input_type_b>(),
b_scale_thread_vec
.template AsType<mfma_scale_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
});
});
});
@@ -587,20 +650,16 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
// TODO: make this field protected when a_scale_thread_copy_ is moved
// here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from a_scale_grid to a_scale_thread
static constexpr auto a_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
make_tuple(Number<MRepeat / MXdlPack>{},
Number<KRepeat / KXdlPack>{},
Number<ScalesPerXdlopsRunPerThreadA * a_scale_thread_vec_size>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved
// here
static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<NRepeat>{}, Number<KRepeat>{}, Number<ScalesPerXdlopsRunPerThread>{}));
// Is used to copy data from b_scale_grid to b_scale_thread_buf
static constexpr auto b_scale_thread_desc_copy =
make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{}));
make_tuple(Number<NRepeat / NXdlPack>{},
Number<KRepeat / KXdlPack>{},
Number<ScalesPerXdlopsRunPerThreadB * b_scale_thread_vec_size>{}));
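make_naive_tensor_descriptor_packed builds a dense row-major descriptor, so the scale buffers above are addressed as a (MRepeat/MXdlPack) x (KRepeat/KXdlPack) x vec array. The offset arithmetic it implies, as a quick standalone sketch:

#include <cassert>

// Dense row-major offset for a 3-D descriptor with extents d0 x d1 x d2,
// mirroring what a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)) yields.
constexpr int packed_offset(int i0, int i1, int i2, int d1, int d2)
{
    return (i0 * d1 + i1) * d2 + i2;
}

int main()
{
    constexpr int d0 = 2, d1 = 2, d2 = 4; // e.g. MRepeat/MXdlPack, KRepeat/KXdlPack, vec
    static_assert(packed_offset(1, 1, 3, d1, d2) == d0 * d1 * d2 - 1, "last element");
    assert(packed_offset(0, 1, 0, d1, d2) == d2); // stepping k0 skips one vec
}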
protected:
using Base::a_thread_copy_;

View File

@@ -177,8 +177,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -179,7 +179,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -178,7 +178,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

File diff suppressed because it is too large

View File

@@ -188,7 +188,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -42,10 +42,12 @@ namespace ck {
template <typename ThreadGroup,
typename BlockSliceLengths,
typename ThreadClusterLengths,
typename ThreadClusterArrangeOrder,
typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SrcDimAccessOrder,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t ScalarPerVector>
@@ -61,6 +63,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto block_slice_lengths = BlockSliceLengths{};
static constexpr auto thread_cluster_lengths = ThreadClusterLengths{};
@@ -96,8 +99,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
// VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the
// first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive
// elements = 64 consecutive DWORDs.
#if defined(__gfx950__)
int num_contiguous_dwords = 4;
#else
int num_contiguous_dwords = 1;
#endif
bool is_contiguous = true;
static_for<0, nDim, 1>{}([&](auto i) {
if(is_contiguous)
{
@@ -141,11 +148,11 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"When loading more than one element per thread at once, the contiguous "
"dimension must be the same between source and destination.");
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
static_assert(bytes_per_thread_load == dword_bytes,
"Direct load transfer requires each thread to load exactly a single "
"DWORD of data.");
// constexpr auto dword_bytes = 4;
// constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData);
// static_assert(bytes_per_thread_load == dword_bytes,
// "Direct load transfer requires each thread to load exactly a single "
// "DWORD of data.");
static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
@@ -156,18 +163,45 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
"The number of threads cannot be less than the number of elements in "
"thread cluster lengths.");
static_assert(
AreThreadClusterLengthsValid(),
"Thread cluster lengths are incorrect. They must be set in a way that allows a single "
"wavefront to write contiguous DWORDs into LDS memory. ");
// static_assert(
//     AreThreadClusterLengthsValid(),
//     "Thread cluster lengths are incorrect. They must be set in a way that allows a single "
//     "wavefront to write contiguous DWORDs into LDS memory. ");
const auto thread_cluster_idx =
thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId()));
constexpr auto wave_cluster_lengths = generate_sequence_v2(
[&](auto i) {
// FIXME: wave parallelism is not always in this dimension.
// ThreadClusterLengths{} must be larger than the number of waves.
if constexpr(ThreadClusterArrangeOrder{}.At(i) == (nDim - 3))
{
return Number<ThreadGroup::GetNumOfThread() / 64>{};
}
else
{
return I1;
}
},
Number<nDim>{});
constexpr auto wave_thread_cluster_lengths = ThreadClusterLengths{} / wave_cluster_lengths;
constexpr auto wave_single_load_size =
wave_thread_cluster_lengths * thread_single_load_size;
constexpr auto wave_cluster_desc_ =
make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{});
const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex(
make_multi_index(ThreadGroup::GetThreadId() / 64));
const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size;
const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size;
SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin);
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin);
// No thread-wise offset is needed for LDS since it is calculated by HW,
// but the wave-wise offset must still be provided.
SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin);
}
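The constructor now decomposes the thread cluster into a per-wave grid: the wave index is the thread id divided by the wavefront size (assumed 64 here, as in the code above), and only the wave-level origin is applied on the LDS side because the hardware supplies the lane offset for direct loads. The index split itself is simple integer math:

#include <cstdio>

int main()
{
    constexpr int wave_size  = 64;  // assumed wavefront width
    constexpr int block_size = 256; // ThreadGroup::GetNumOfThread()
    constexpr int num_waves  = block_size / wave_size;

    const int tids[] = {0, 63, 64, 255};
    for(int tid : tids)
    {
        int wave = tid / wave_size; // wave_cluster_idx along the wave dimension
        int lane = tid % wave_size; // supplied by HW on the LDS write side
        std::printf("tid %3d -> wave %d/%d lane %2d\n", tid, wave, num_waves, lane);
    }
}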
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
@@ -215,7 +249,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
// Loop over the destination block and copy data.
static_ford<decltype(dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
const auto src_offset = src_coord_.GetOffset();
const auto dst_offset = dst_coord_.GetOffset();
const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset());
// Check that the src data is not in the logical padding area.
const bool is_src_valid =
@@ -303,7 +337,8 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
}
private:
static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{});
static constexpr auto thread_cluster_desc_ =
make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
SrcCoord src_coord_;
DstCoord dst_coord_;

View File

@@ -45,6 +45,44 @@ struct DeviceGemmMX : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
index_t ScaleBlockSize,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmMX_BPreshuffle : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideAScale,
ck::index_t StrideB,
ck::index_t StrideBScale,
ck::index_t StrideC,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck
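The new interface mirrors DeviceGemmMX but adds GetPreShuffleParameters, since B must be reordered into the MFMA-friendly layout before the kernel runs. A hedged usage sketch against this interface (Op stands for a concrete implementation such as a DeviceGemmMX_Xdl_CShuffleV3 specialization; the element-op types and the stream-config type come from CK headers, and the device buffers, with B already preshuffled, are assumed to exist):

template <typename Op, typename AElem, typename BElem, typename CElem, typename Stream>
float run_mx_gemm(Op& op,
                  const void* p_a, const void* p_a_scale,
                  const void* p_b, const void* p_b_scale, void* p_c,
                  int M, int N, int K,
                  int StrideA, int StrideAScale,
                  int StrideB, int StrideBScale, int StrideC,
                  const Stream& stream_config)
{
    auto arg = op.MakeArgumentPointer(p_a, p_a_scale, p_b, p_b_scale, p_c,
                                      M, N, K,
                                      StrideA, StrideAScale,
                                      StrideB, StrideBScale, StrideC,
                                      /*KBatch=*/1,
                                      AElem{}, BElem{}, CElem{});
    auto invoker = op.MakeInvokerPointer();
    (void)op.GetPreShuffleParameters(); // layout factor consumed by the host-side shuffle
    return invoker->Run(arg.get(), stream_config);
}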

View File

@@ -15,6 +15,7 @@
#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
@@ -162,56 +163,108 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using GridwiseGemm = conditional_t< //
!is_same_v<BLayout, tensor_layout::gemm::MFMA>,
GridwiseGemmMX_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>,
GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle<
ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>>;
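Selecting the gridwise implementation with conditional_t keeps a single device-op class serving both paths: the B layout tag alone decides whether the preshuffled kernel is instantiated. The pattern in isolation, with illustrative stand-ins for the two gridwise kernels:

#include <type_traits>

namespace layout { struct RowMajor {}; struct MFMA {}; }
template <typename BLayout> struct GridwisePlain       { /* B staged through LDS */ };
template <typename BLayout> struct GridwiseBPreshuffle { /* B read directly      */ };

template <typename BLayout>
using GridwiseGemmFor = std::conditional_t<
    !std::is_same_v<BLayout, layout::MFMA>,
    GridwisePlain<BLayout>,        // default path
    GridwiseBPreshuffle<BLayout>>; // BLayout == MFMA: B was preshuffled on host

static_assert(std::is_same_v<GridwiseGemmFor<layout::MFMA>,
                             GridwiseBPreshuffle<layout::MFMA>>);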
using Argument = typename GridwiseGemm::Argument;
@@ -304,385 +357,45 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
: 1
: 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
constexpr auto TailNumChoices = []() {
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
return Tuple<constant<TailNumber::Full>>{};
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
return Tuple<constant<TailNumber::Even>, constant<TailNumber::Odd>>{};
else
static_assert(false, "Unexpected BlkGemmPipelineVer!");
}();
constexpr bool Use2LDS = []() {
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
return false;
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
return true;
else
static_assert(false, "Unexpected BlkGemmPipelineVer!");
}();
const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
using BoolChoices = Tuple<ck::true_type, ck::false_type>;
static_for_product<BoolChoices,
BoolChoices,
remove_cvref_t<decltype(TailNumChoices)>>{}(
[&](auto mainloop_choice, auto KBatch_cond_choice, auto tail_num_choice) {
constexpr auto CGlobalMemoryDataOperation =
KBatch_cond_choice.value ? InMemoryDataOperationEnum::AtomicAdd
: InMemoryDataOperationEnum::Set;
if(mainloop_choice.value == has_main_k_block_loop &&
KBatch_cond_choice.value == (arg.KBatch > 1) &&
tail_num_choice.value == tail_num)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx< //
Use2LDS,
GridwiseGemm,
mainloop_choice.value,
CGlobalMemoryDataOperation,
minimum_occupancy,
tail_num_choice.value>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
});
return ave_time;
}
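The rewritten launcher replaces the long if/else ladder with a compile-time cross product: every (has-main-loop, KBatch>1, tail-number) combination gets its own kernel instantiation, and exactly one branch fires at run time. The shape of that pattern, reduced to standard C++ (CK's static_for_product and constant<> are assumed; plain template instantiations stand in for them here):

#include <cstdio>

enum class Tail { Odd, Even };

// Stand-in "kernel": the choices are baked in as template parameters,
// as in kernel_gemm_xdl_cshuffle_v3_mx.
template <bool MainLoop, bool Atomic, Tail T>
void launch_stub() { std::printf("kernel<%d,%d,%d>\n", MainLoop, Atomic, (int)T); }

// Try one compile-time combination against the runtime configuration.
template <bool MainLoop, bool Atomic, Tail T>
void try_launch(bool main_rt, bool atomic_rt, Tail tail_rt)
{
    if(MainLoop == main_rt && Atomic == atomic_rt && T == tail_rt)
        launch_stub<MainLoop, Atomic, T>();
}

template <bool MainLoop, bool Atomic>
void try_tails(bool m, bool a, Tail t)
{
    try_launch<MainLoop, Atomic, Tail::Odd>(m, a, t);
    try_launch<MainLoop, Atomic, Tail::Even>(m, a, t);
}

int main()
{
    bool has_main_loop = true;  // GridwiseGemm::CalculateHasMainKBlockLoop(...)
    bool atomic        = false; // arg.KBatch > 1 -> AtomicAdd, else Set
    Tail tail          = Tail::Even;

    // Cross product {MainLoop} x {Atomic} x {Tail}: 8 instantiations, 1 launch.
    try_tails<true,  true >(has_main_loop, atomic, tail);
    try_tails<true,  false>(has_main_loop, atomic, tail);
    try_tails<false, true >(has_main_loop, atomic, tail);
    try_tails<false, false>(has_main_loop, atomic, tail);
}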

View File

@@ -98,10 +98,12 @@ struct DeviceGemmXdlSplitKCShuffle_LdsDirectLoad : public DeviceGemmSplitK<ALayo
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferScalarPerVector,
ABlockLdsAddExtraM,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferScalarPerVector,
BBlockLdsAddExtraN,

View File

@@ -315,6 +315,13 @@ struct PassThrough
y = x;
}
template <>
__host__ __device__ void operator()<f4x2_pk_t, f4x2_pk_t>(f4x2_pk_t& y,
const f4x2_pk_t& x) const
{
y = x;
}
template <>
__host__ __device__ void operator()<double, double>(double& y, const double& x) const
{

View File

@@ -173,18 +173,34 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
__host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
{
// A matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1, AK1, I1));
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
// FIXME: support for non-K-contiguous layouts is limited and only works in
// some specific settings
return make_naive_tensor_descriptor_packed(
make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1));
}
else
{
return make_naive_tensor_descriptor(make_tuple(AK0PerBlock, Number<MPerBlock>{}, AK1),
make_tuple(AK1, Number<KPerBlock>{}, I1));
}
}
__host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
{
// B matrix in LDS memory, destination of blockwise copy.
return make_naive_tensor_descriptor(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1, BK1, I1));
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
// FIXME: support for non-K-contiguous layouts is limited and only works in
// some specific settings
return make_naive_tensor_descriptor_packed(
make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1));
}
else
{
return make_naive_tensor_descriptor(make_tuple(BK0PerBlock, Number<NPerBlock>{}, BK1),
make_tuple(BK1, Number<KPerBlock>{}, I1));
}
}
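The two descriptor flavors differ only in stride layout: the packed form is dense in (K0, N, K1) order, while the explicit form keeps K1 fastest with stride 1 and gives the N (or M) axis stride KPerBlock, so the whole K extent stays contiguous per row. The offset functions the two forms imply, as a quick check:

#include <cassert>

int main()
{
    constexpr int K0 = 4, M = 8, K1 = 8; // illustrative block shape
    constexpr int KPerBlock = K0 * K1;

    // Packed (K0, M, K1): dense row-major offsets.
    auto packed  = [&](int k0, int m, int k1) { return (k0 * M + m) * K1 + k1; };
    // Strides (K1, KPerBlock, 1): K contiguous per m, as in the else branch above.
    auto strided = [&](int k0, int m, int k1) { return k0 * K1 + m * KPerBlock + k1; };

    // For a fixed m, all (k0, k1) land in one contiguous KPerBlock-element run.
    assert(strided(0, 3, 0) == 3 * KPerBlock);
    assert(strided(K0 - 1, 3, K1 - 1) == 4 * KPerBlock - 1);

    // In the packed form, adjacent m for a fixed (k0, k1) are only K1 apart.
    assert(packed(0, 1, 0) - packed(0, 0, 0) == K1);
}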
__host__ __device__ static constexpr auto
@@ -566,10 +582,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<AK0PerBlock, MPerBlock, AK1>,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferSrcAccessOrder,
ADataType,
AComputeDataType,
decltype(a_grid_desc_ak0_m_ak1),
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
2,
ABlockTransferScalarPerVector>(
@@ -582,10 +600,12 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<BK0PerBlock, NPerBlock, BK1>,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferSrcAccessOrder,
BDataType,
BComputeDataType,
decltype(b_grid_desc_bk0_n_bk1),
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
2,
BBlockTransferScalarPerVector>(

View File

@@ -256,8 +256,9 @@ struct GridwiseGemm_xdl_cshuffle_v3
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
// gfx950 double-rate mfma16x16 requires KPerBlock >= 128 to be fully consumed
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
KPerBlock < 128 && MPerXdl == 16))
? true
: false;
static constexpr auto is_scale_mfma = false;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -184,8 +184,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
(((is_same<ComputeTypeA, half_t>::value || is_same<ComputeTypeA, bhalf_t>::value) &&
lcm_AK1_BK1 <= 4) ||
(is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
// gfx950 double-rate mfma16x16 requires KPerBlock >= 128 to be fully consumed
((is_same<ComputeTypeA, f8_t>::value || is_same<ComputeTypeA, bf8_t>::value) &&
lcm_AK1_BK1 < 32))
KPerBlock < 128 && MPerXdl == 16))
? true
: false;
static constexpr auto is_scale_mfma = false;

View File

@@ -173,15 +173,25 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
lcm_AK1_BK1 < 32))
? true
: false;
static constexpr auto is_scale_mfma = false;
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
static constexpr auto is_scale_mfma = false;
static constexpr auto mfma = MfmaSelector<ComputeTypeA,
MPerXdl,
NPerXdl,
ComputeTypeA,
is_single_rate_mfma,
is_scale_mfma>{};
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
static constexpr index_t KGroup = mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
static constexpr index_t KPack = math::max(lcm_AK1_BK1, mfma.selected_mfma.k_per_blk);
static constexpr index_t KGroup = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, f8_t>)
// On gfx950, there is an MFMA that requires 32 f8 elements as input,
// split into 2 groups of 16 f8 elements.
// The 2 groups are not contiguous in the B preshuffled layout,
// and we do not want them to be contiguous there,
// because a memory instruction can only read 16 f8 elements at a time.
return mfma.selected_mfma.k_per_blk == 32 ? 2 : 1;
else
return 1;
}();
static constexpr index_t KLane = mfma.GetKPerXdlops() / mfma.GetK1PerXdlops();
static constexpr index_t KPackPerGroup = KPack / KGroup;
static constexpr index_t KRepeat = KPerBlock / KLane / KPackPerGroup;
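Worked numbers make the KGroup split concrete. A hedged example with assumed gfx950-like f8 parameters (values chosen for illustration, not queried from CK):

#include <cstdio>

int main()
{
    constexpr int k_per_blk     = 32;                       // f8 elements per MFMA issue
    constexpr int KGroup        = k_per_blk == 32 ? 2 : 1;  // two 16-element groups
    constexpr int KPack         = 32;
    constexpr int KPackPerGroup = KPack / KGroup;           // 16: one 16-byte read each
    constexpr int KLane         = 2;                        // KPerXdlops / K1PerXdlops
    constexpr int KPerBlock     = 128;
    constexpr int KRepeat       = KPerBlock / KLane / KPackPerGroup; // 4

    // Each group is a separate 16-byte read, which is why the preshuffled B
    // layout keeps the two groups apart instead of merging them into one run.
    std::printf("KGroup=%d KPackPerGroup=%d KRepeat=%d\n",
                KGroup, KPackPerGroup, KRepeat);
}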

View File

@@ -76,10 +76,12 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_K0_M_K1,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_K0_N_K1,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
bool BBlockLdsExtraN,
@@ -102,9 +104,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
static constexpr auto I7 = Number<7>{};
// K1 should be Number<...>
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
static constexpr auto K1 = Number<K1Value>{};
static constexpr auto KPerBlock = Number<K1Value * K0PerBlock>{};
static constexpr auto M01 = 1;
static constexpr auto N01 = 1;
static constexpr auto gemm_padder =
tensor_operation::device::GemmPadder<GemmSpec, index_t, index_t, index_t>{
@@ -613,8 +616,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
make_tuple(K1, Number<KPerBlock>{}, I1));
}
}();
@@ -630,9 +634,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor_aligned(
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
max_lds_align);
make_tuple(
Number<KPerBlock>{} * Number<MPerBlock>{}, K1, Number<KPerBlock>{}, I1));
}
}();
// B matrix in LDS memory, dst of blockwise copy
@@ -645,8 +650,9 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor_aligned(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
return make_naive_tensor_descriptor(
make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
make_tuple(K1, Number<KPerBlock>{}, I1));
}
}();
@@ -662,9 +668,10 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
}
else
{
return make_naive_tensor_descriptor_aligned(
return make_naive_tensor_descriptor(
make_tuple(Number<1>{}, Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
max_lds_align);
make_tuple(
Number<KPerBlock>{} * Number<NPerBlock>{}, K1, Number<KPerBlock>{}, I1));
}
}();
@@ -672,10 +679,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<1, K0PerBlock, MPerBlock, K1>,
ABlockTransferThreadClusterLengths_K0_M_K1,
ABlockTransferSrcAccessOrder,
FloatA,
ComputeType,
decltype(a_b_k0_m_k1_grid_desc),
decltype(a_b_k0_m_k1_block_desc),
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
3,
ABlockTransferSrcScalarPerVector>(
@@ -688,10 +697,12 @@ struct GridwiseGemm_xdlops_splitk_lds_direct_load
ThreadGroupTensorSliceTransfer_DirectLoad<ThisThreadBlock,
Sequence<1, K0PerBlock, NPerBlock, K1>,
BBlockTransferThreadClusterLengths_K0_N_K1,
BBlockTransferSrcAccessOrder,
FloatB,
ComputeType,
decltype(b_b_k0_n_k1_grid_desc),
decltype(b_b_k0_n_k1_block_desc),
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
3,
BBlockTransferSrcScalarPerVector>(

View File

@@ -260,7 +260,8 @@ struct ThreadwiseTensorSliceTransfer_v2
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
@@ -422,6 +423,240 @@ struct ThreadwiseTensorSliceTransfer_v2
SrcCoord src_coord_;
}; // namespace ck
template <typename SrcData,
typename DstData,
typename SrcDesc,
typename DstDesc,
typename SliceLengths,
typename DimAccessOrder,
index_t SrcVectorDim,
index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun,
index_t scale_gather_num,
bool InvalidElementAsNaN = false,
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2_gather
{
static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
(!InvalidElementAsNaN),
"Filling invalid element as NaN is only for floating point types");
static constexpr index_t nDim = SliceLengths::Size();
using Index = MultiIndex<nDim>;
using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
static constexpr index_t PackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
return 2;
else
return 1;
}();
__device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(
const SrcDesc& src_desc,
const Index& src_slice_origin_idx,
const StaticallyIndexedArray<index_t, scale_gather_num>& scale_gather_offsets)
: src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)),
scale_gather_offsets_(scale_gather_offsets)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
}
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
{
auto adjusted_origin_idx = [&]() {
Index idx;
static_for<0, nDim, 1>{}(
[&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number<i>{}]; });
return idx;
}();
src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
}
template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
__device__ void Run(const SrcDesc& src_desc,
const SrcBuffer& src_buf,
const DstDesc&,
const DstSliceOriginIdx&,
DstBuffer& dst_buf)
{
static_assert(DstDesc::IsKnownAtCompileTime(),
"wrong! DstDesc needs to be known at compile-time");
static_assert(is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
"wrong! DstSliceOrigin needs to be known at compile-time");
static_assert(
is_same<remove_cvref_t<typename DstBuffer::type>, remove_cvref_t<DstData>>::value &&
"wrong! inconsistent type");
// DstDesc and dst_slice_origin_idx are known at compile-time
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
// scalar per access on each dim
// TODO: don't use lambda_scalar_per_access
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
constexpr auto src_scalar_step_in_vector =
generate_sequence(detail::lambda_scalar_step_in_vector<SrcVectorDim>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
// loop over tensor and copy
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) {
constexpr auto current_dst_origin =
to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0);
static_for<0, num_access, 1>{}([&](auto idx_1d) {
typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
src_vector;
using src_vector_t =
typename vector_type_maker<SrcData,
SrcScalarPerVector / PackedSize>::type::type;
constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
src_coord_);
// copy data from src_buf into src_vector
src_vector.template AsType<src_vector_t>()(Number<0>{}) =
src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize +
scale_gather_offsets_(gather_idx),
is_src_valid);
// copy data from src_vector into dst_buf
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t dst_offset =
dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
src_data_idx + i * src_scalar_step_in_vector);
constexpr auto full_dst_offset =
dst_desc.CalculateOffset(current_dst_origin) + dst_offset;
if constexpr(InvalidElementAsNaN)
{
dst_buf(full_dst_offset) =
is_src_valid
? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
: NumericLimits<DstData>::QuietNaN();
}
else
{
dst_buf(Number<full_dst_offset>{}) =
type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
}
});
if constexpr(idx_1d.value != num_access - 1)
{
constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
move_tensor_coordinate(
src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
}
});
});
// printf("blockIdx.y: %d, tid: %d, dst_buf<%f>\n",
// blockIdx.y,
// threadIdx.x,
// dst_buf(Number<0>{}));
// move src coordinate back to slice origin (or not)
if constexpr(SrcResetCoordinateAfterRun)
{
const auto src_reset_step =
make_tensor_coordinate_step(src_desc, GetSrcCoordinateResetStep());
move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
}
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
{
constexpr auto src_scalar_per_access = generate_sequence(
detail::lambda_scalar_per_access<SrcVectorDim, SrcScalarPerVector>{}, Number<nDim>{});
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
DimAccessOrder,
remove_cv_t<decltype(src_scalar_per_access)>>;
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
if constexpr(num_access == 0)
{
return typename SpaceFillingCurve::Index{};
}
else
{
constexpr auto reset_step =
SpaceFillingCurve::GetStepBetween(Number<num_access - 1>{}, Number<0>{});
return reset_step;
}
}
    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
__device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx)
{
        // if the src coordinate was not reset by Run(), adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
template <typename SrcMoveSliceWindowStepHack>
__device__ void
MoveSrcSliceWindow(const SrcDesc& src_desc,
const Index& src_slice_origin_step_idx,
const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
{
        // if the src coordinate was not reset by Run(), adjust the step here
const auto adjusted_step_idx =
SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
: src_slice_origin_step_idx + GetSrcCoordinateResetStep();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(
src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
}
private:
SrcCoord src_coord_;
StaticallyIndexedArray<index_t, scale_gather_num> scale_gather_offsets_;
}; // struct ThreadwiseTensorSliceTransfer_v2_gather
// Assume:
// 1. src_desc and dst_desc are not known at compile-time
// 2. SrcBuffer and DstBuffer are DynamicBuffer
@@ -1053,10 +1288,8 @@ struct ThreadwiseTensorSliceTransfer_v4
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
"wrong! SrcDesc and DstDesc need to known at compile-time");
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! Not divisible");
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t> ||
is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>)
{
static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
}
@@ -1236,16 +1469,16 @@ struct ThreadwiseTensorSliceTransfer_v4
{
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
// DstData)
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
vector_type_maker_t<DstData, SrcScalarPerVector / PackedSize> dst_tmp_vector;
            // TODO: if SrcData and DstData are vector types, then static_cast may not compile
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
dst_tmp_vector.template AsType<DstData>()(i) =
type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
});
// copy data from dst_tmp_vector into dst_buf
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
constexpr index_t dst_offset = dst_desc.CalculateOffset(
dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);

View File

@@ -62,6 +62,18 @@ struct lambda_scalar_per_access_for_src_and_dst
}
};
template <index_t WaveNum, index_t nDim>
struct lambda_wave_cluster_dimension
{
__host__ __device__ constexpr auto operator()(index_t i) const
{
if((nDim - i) == 3)
return WaveNum;
else
return 1;
}
};
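// Illustration (hypothetical values): with WaveNum = 4 and nDim = 4,
// generate_sequence(detail::lambda_wave_cluster_dimension<4, 4>{}, Number<4>{})
// yields Sequence<1, 4, 1, 1>: the wave count lands on the third dimension
// from the end and every other dimension gets a cluster length of 1.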
} // namespace detail
} // namespace ck

View File

@@ -90,7 +90,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
src_element_op_(src_element_op),
dst_element_op_(dst_element_op)
{
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
@@ -99,7 +99,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
static_assert(SrcVectorDim == DstVectorDim,
"Packed data type does not support transpose");
}
}
@@ -444,6 +445,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
{
static_assert(!is_same_v<remove_cvref_t<SrcData>, pk_i4_t>,
"in-register transpose is not supported for pk_i4_t");
static_assert(!is_same_v<remove_cvref_t<SrcData>, f4x2_pk_t>,
"in-register transpose is not supported for f4x2_pk_t");
// each transpose does
// DstScalarPerVector # of src vectors in src_thread_scratch_
// SrcScalarPerVector # of dst vectors in dst_thread_scratch_

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -96,7 +96,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
dst_element_op_(dst_element_op),
gather_offsets_(gather_offsets)
{
if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
if constexpr((packed_size_v<SrcData>) > 1)
{
static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
"SrcData != DstData");
@@ -105,7 +105,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
SrcScalarPerVector_ % PackedSize == 0 && DstScalarPerVector_ % PackedSize == 0,
"SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1 for packed data type");
static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
static_assert(SrcVectorDim == DstVectorDim,
"Packed data type does not support transpose");
}
}
@@ -222,7 +223,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
auto gather_offset =
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
const IndexType ld_offset = src_coord_.GetOffset() + gather_offset;
const IndexType ld_offset = src_coord_.GetOffset() / PackedSize + gather_offset;
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, true);

View File

@@ -410,8 +410,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
IndexType dst_offset = scatter_offset + (dst_coords_[i].GetOffset());
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(

View File

@@ -8,6 +8,35 @@
#include "ck/utility/amd_xdlops.hpp"
namespace ck {
/**
 * @brief Checks whether a matrix data type has hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_data_type()
{
using U = element_type_t<T>;
return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
}
/**
 * @brief Checks whether a scale data type has hardware support for MX GEMMs
*/
template <typename T>
static constexpr bool is_scale_mfma_scale_type()
{
return is_same_v<T, e8m0_bexp_t>;
}
/**
 * @brief Checks whether a combination of data types has hardware support for MX GEMMs
*/
template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
static constexpr bool scale_mfma_hw_support()
{
return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
}
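// Usage sketch (illustrative only, not part of this change): a kernel can
// reject unsupported MX type combinations at compile time, e.g.
//   static_assert(scale_mfma_hw_support<f8_ocp_t, f4x2_pk_t, e8m0_bexp_t, e8m0_bexp_t>(),
//                 "this A/B/scale combination has no scaled-MFMA hardware path");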
enum struct MfmaInstr
{
@@ -847,6 +876,8 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
@@ -858,11 +889,9 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops>::Run(
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
intrin_mfma_scale_f32_32x32x64f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
}
};
@@ -885,6 +914,8 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
template <index_t MPerXdlops,
index_t NPerXdlops,
index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
@@ -896,11 +927,9 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
const ScaleB& scale_b,
FloatC& reg_c) const
{
static_assert(scalar_type<ScaleA>::vector_size == 1, "Expect single scale at this point.");
static_assert(scalar_type<ScaleB>::vector_size == 1, "Expect single scale at this point.");
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops>::Run(
a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c);
intrin_mfma_scale_f32_16x16x128f8f6f4<MPerXdlops, NPerXdlops, OpselA, OpselB>::Run(
a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
}
};
@@ -1117,7 +1146,7 @@ struct MfmaSelector
#endif
}
// Use singal rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
// Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
// See example gemm_xdl_fp8_pk_i4_bpreshuffle_v3
// TODO: explore optimization opportunity by using new mfma instructions on gfx950
template <>
@@ -1153,6 +1182,16 @@ struct MfmaSelector
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f4_t, 32, 32, f4_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4;
}
template <>
constexpr auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
{
return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4;
}
template <>
constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
@@ -1290,10 +1329,10 @@ struct MfmaSelector
#endif
}
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type,
static constexpr auto selected_mfma = mfma_type<GetMfma<element_type_t<base_type>,
MPerXdlops,
NPerXdlops,
additional_type,
element_type_t<additional_type>,
is_single_rate_mfma,
is_scale_mfma>()>{};
@@ -1375,7 +1414,8 @@ struct XdlopsGemm
MPerXdlops == 64,
"Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
        static_assert(KPack * 2 % mfma_instr.k_per_blk == 0,
                      "KPack * 2 should be a multiple of k_per_blk");
}
// XDL output supporting C = A * B
@@ -1413,6 +1453,49 @@ struct XdlopsGemm
Sequence<7>{}));
}
// XDL output supporting C = A * B
// M3_N3 -> M3_M4_M5_N3
template <typename CDesc_M0_N0_M1_N1_M2_N2>
__host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
{
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
return transform_tensor_descriptor(
c_desc_m0_n0_m1_n1_m2_n2,
make_tuple(make_pass_through_transform(M0),
make_pass_through_transform(N0),
make_pass_through_transform(M1),
make_pass_through_transform(N1),
make_pass_through_transform(M2),
make_pass_through_transform(N2),
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
Number<mfma_instr.num_input_blks>{},
Number<mfma_instr.group_size>{})),
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6, 7, 8>{},
Sequence<9>{}));
}
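    // Worked example (illustrative numbers): for the 32x32 MFMA shape,
    // num_groups_per_blk = 4, num_input_blks = 2 and group_size = 4, so the
    // unmerge above splits the M3 length 32 into (M3, M4, M5) = (4, 2, 4),
    // while N3 (num_threads_per_blk = 32) passes through, producing the
    // 10-dimensional M0_N0_M1_N1_M2_N2_M3_M4_M5_N3 descriptor.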
// transposed XDL output supporting C' = B' * A'
// M2_N2 -> M2_N2_N3_N4
template <typename CDesc_M0_N0_M1_N1_M2_N2>
@@ -1518,7 +1601,13 @@ struct XdlopsGemm
});
}
template <class FloatA, class ScaleA, class FloatB, class ScaleB, class FloatC>
template <index_t OpselA,
index_t OpselB,
class FloatA,
class ScaleA,
class FloatB,
class ScaleB,
class FloatC>
__device__ void Run(const FloatA& p_a_wave,
const ScaleA& a_scale_thread,
const FloatB& p_b_wave,
@@ -1528,12 +1617,12 @@ struct XdlopsGemm
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
}
else
{
mfma_instr.template run<MPerXdlops, NPerXdlops>(
mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
}
});

View File

@@ -430,7 +430,9 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
(is_same<T, bf8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, uint8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
(is_same<T, pk_i4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(is_same<T, f4x2_pk_t::type>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using r_t = typename vector_type<T, N>::type;
@@ -1018,18 +1020,18 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
const index_t src_element_space_size)
{
// Direct loads require that each thread reads and writes exactly a single DWORD.
constexpr auto dword_bytes = 4;
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
#if defined(__gfx950__)
constexpr auto dword_bytes = 4;
static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 ||
bytes_per_thread == dword_bytes * 4);
#elif defined(__gfx942__)
constexpr auto dword_bytes = 4;
static_assert(bytes_per_thread == dword_bytes);
#ifndef CK_CODE_GEN_RTC
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
#else
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
#endif
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const int32x4_t src_resource =
make_wave_buffer_resource(global_base_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM
@@ -1057,7 +1059,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
#endif
llvm_amdgcn_raw_buffer_load_lds(
src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0);
src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0);
#endif
}
#endif

View File

@@ -843,14 +843,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread;
static_assert(bytes_per_thread == dword_bytes);
#ifndef CK_CODE_GEN_RTC
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<uintptr_t>(global_base_ptr));
#else
const uint32_t* global_ptr =
reinterpret_cast<uint32_t*>(reinterpret_cast<size_t>(global_base_ptr));
#endif
const int32x4_t src_resource = make_wave_buffer_resource(global_ptr, src_element_space_size);
const int32x4_t src_resource =
make_wave_buffer_resource(global_base_ptr, src_element_space_size);
const index_t global_offset_bytes = is_valid ? global_offset * sizeof(T) : 0x80000000;
#if CK_USE_AMD_LDS_DIRECT_LOAD_INLINE_ASM

View File

@@ -662,11 +662,11 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
}
};
template <index_t MPerWave, index_t NPerWave>
template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
struct intrin_mfma_scale_f32_32x32x64f8f6f4;
template <>
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
template <index_t OpselA, index_t OpselB>
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB>
{
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
@@ -682,11 +682,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0, // OPSEL
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
// XXX: Note on the scale_a and scale_b parameters:
// If compiler detects that one or both scales are constant values, it will treat that
@@ -719,11 +719,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
0, // OPSEL
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
// XXX: Note on the scale_a and scale_b parameters:
// If compiler detects that one or both scales are constant values, it will treat that
@@ -756,11 +756,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
reg_a,
reg_b,
reg_c.template AsType<float16_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0, // OPSEL
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
// XXX: Note on the scale_a and scale_b parameters:
// If compiler detects that one or both scales are constant values, it will treat that
@@ -798,11 +798,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
0, // OPSEL
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
ignore = reg_a;
@@ -832,11 +832,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
0, // OPSEL
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
ignore = reg_a;
@@ -866,11 +866,11 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
4, // blgp
0, // OPSEL
4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
4, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
ignore = reg_a;
@@ -881,13 +881,60 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
#endif
}
};
#define BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS 1
template <index_t MPerWave, index_t NPerWave>
#ifndef BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
#define BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS 0
#endif
template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
struct intrin_mfma_scale_f32_16x16x128f8f6f4;
template <>
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
template <index_t OpselA, index_t OpselB>
struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
{
#define V_MFMA_SCALE_F32_16X16X128_F8F6F4(OPF_F8F6F4_CTRL_A, \
OPF_F8F6F4_CTRL_B, \
F8F6F4_VEC_TYPE_A, \
F8F6F4_VEC_TYPE_B, \
OPSEL_A_L, \
OPSEL_A_H, \
OPSEL_B_L, \
OPSEL_B_H) \
if constexpr((OpselA == 1 * OPSEL_A_L + 2 * OPSEL_A_H) && \
(OpselB == 1 * OPSEL_B_L + 2 * OPSEL_B_H)) \
asm volatile("v_mfma_scale_f32_16x16x128_f8f6f4 %0, %1, %2, %3, %4, %5 " \
"op_sel:[" #OPSEL_A_L "," #OPSEL_A_H "] " \
"op_sel_hi:[" #OPSEL_B_L "," #OPSEL_B_H "] " \
"cbsz:" #OPF_F8F6F4_CTRL_A " blgp:" #OPF_F8F6F4_CTRL_B \
: "+v"(reg_c.template AsType<float4_t>()(Number<0>{})) \
: "v"(bit_cast<F8F6F4_VEC_TYPE_A>(reg_a)), \
"v"(bit_cast<F8F6F4_VEC_TYPE_B>(reg_b)), \
"v"(reg_c.template AsType<float4_t>()[Number<0>{}]), \
"v"(scale_a), \
"v"(scale_b))
#define BOOL4_CASES(F) \
do \
{ \
F(0, 0, 0, 0); \
F(0, 0, 0, 1); \
F(0, 0, 1, 0); \
F(0, 0, 1, 1); \
F(0, 1, 0, 0); \
F(0, 1, 0, 1); \
F(0, 1, 1, 0); \
F(0, 1, 1, 1); \
F(1, 0, 0, 0); \
F(1, 0, 0, 1); \
F(1, 0, 1, 0); \
F(1, 0, 1, 1); \
F(1, 1, 0, 0); \
F(1, 1, 0, 1); \
F(1, 1, 1, 0); \
F(1, 1, 1, 1); \
} while(0)
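// Dispatch note: inline asm needs literal op_sel bits, so BOOL4_CASES
// instantiates all 16 (OPSEL_A_L, OPSEL_A_H, OPSEL_B_L, OPSEL_B_H)
// combinations and the if-constexpr guard in V_MFMA_SCALE_F32_16X16X128_F8F6F4
// keeps exactly the one whose encoding matches (OpselA, OpselB); the other 15
// compile away. For example, OpselA = 2, OpselB = 1 selects the F(0, 1, 1, 0)
// case.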
template <class FloatC>
__device__ static void Run(const f8x32_t& reg_a,
const int32_t& scale_a,
@@ -896,18 +943,24 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
FloatC& reg_c)
{
#if defined(__gfx950__)
#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0, // OPSEL
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
#define f8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(0, 0, int32x8_t, int32x8_t, __VA_ARGS__)
BOOL4_CASES(f8_cases);
#undef f8_cases
#endif
#else
ignore = reg_a;
ignore = scale_a;
@@ -925,18 +978,23 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
FloatC& reg_c)
{
#if defined(__gfx950__)
#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
0, // OPSEL
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
#define bf8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(1, 1, int32x8_t, int32x8_t, __VA_ARGS__)
        BOOL4_CASES(bf8_cases);
#undef bf8_cases
#endif
#else
ignore = reg_a;
ignore = scale_a;
@@ -954,18 +1012,24 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
FloatC& reg_c)
{
#if defined(__gfx950__)
#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
0, // OPSEL
0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
#define f8bf8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(0, 1, int32x8_t, int32x8_t, __VA_ARGS__)
BOOL4_CASES(f8bf8_cases);
#undef f8bf8_cases
#endif
#else
ignore = reg_a;
ignore = scale_a;
@@ -983,18 +1047,24 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
FloatC& reg_c)
{
#if defined(__gfx950__)
#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
// https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
reg_a,
reg_b,
reg_c.template AsType<float4_t>()[Number<0>{}],
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
0, // OPSEL
1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
0, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
#define bf8f8_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(1, 0, int32x8_t, int32x8_t, __VA_ARGS__)
BOOL4_CASES(bf8f8_cases);
#undef bf8f8_cases
#endif
#else
ignore = reg_a;
ignore = scale_a;
@@ -1022,11 +1092,11 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
0, // OPSEL
2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
2, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
ignore = reg_a;
@@ -1055,11 +1125,11 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
0, // OPSEL
3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
3, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
ignore = reg_a;
@@ -1071,29 +1141,43 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a,
const int32_t scale_a,
const f4x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
__device__ static void
Run(const f4x32_t& reg_a, // misalignment between pk_f4_t, 32 and f4_t, 32
const int32_t scale_a,
const f4x32_t& reg_b,
const int32_t scale_b,
FloatC& reg_c)
{
#if defined(__gfx950__)
#if BUILTIN_AMDGCN_MFMA_SCALE_F32_16X16X128_F8F6F4_WORKS
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
using arg_type = int32x8_t;
using arg_type = int32x8_t;
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
4, // cbsz
4, // blgp
0, // OPSEL
4, // cbsz
4, // blgp
OpselA, // OPSEL
scale_a,
0, // OPSEL
OpselB, // OPSEL
scale_b);
#else
#define f4_cases(...) V_MFMA_SCALE_F32_16X16X128_F8F6F4(4, 4, int32x4_t, int32x4_t, __VA_ARGS__)
BOOL4_CASES(f4_cases);
#undef f4_cases
#endif
#else
ignore = reg_a;
ignore = scale_a;
@@ -1102,7 +1186,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
ignore = reg_c;
#endif
}
};
#undef BOOL4_CASES
#undef V_MFMA_SCALE_F32_16X16X128_F8F6F4
}; // struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x128f8f6f4;

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -71,7 +71,8 @@ template <index_t BlockSize,
index_t NRepeat,
index_t MPerXDL,
index_t NPerXDL,
index_t KPerXDL>
index_t KPerXDL,
bool IsF4F6 = false>
struct BlockwiseGemmXdlops_pipeline_hotloop_inst
{
static constexpr index_t WaveSize = 64;
@@ -99,14 +100,16 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
static constexpr index_t C_MFMA_Inst_Num =
MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
static constexpr index_t C_MFMA_SpeedUp = IsF4F6 ? 2 : 1;
static constexpr index_t C_MFMA_Inst_Cycle = []() {
if constexpr(NPerXDL == 16)
{
return KPerXDL == 128 ? 32 : 16;
return KPerXDL == 128 ? 32 / C_MFMA_SpeedUp : 16 / C_MFMA_SpeedUp;
}
else if constexpr(NPerXDL == 32)
{
return KPerXDL == 64 ? 64 : 32;
return KPerXDL == 64 ? 64 / C_MFMA_SpeedUp : 32 / C_MFMA_SpeedUp;
}
}();
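    // Worked example of the model above: an FP4/FP6 MFMA is assumed to retire
    // in half the cycles of the FP8 variant on the same tile, so with
    // NPerXDL == 16 and KPerXDL == 128 the estimate drops from 32 to 16 cycles
    // when IsF4F6 is set, and from 64 to 32 for NPerXDL == 32, KPerXDL == 64.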
@@ -123,7 +126,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
KPerXDL);
printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
"%d, %d\n C MFMA inst: %d\n"
"%d, %d\n C MFMA inst: %d C MFMA cycle: %d\n"
"A/B LDS read width: %d, %d, A/B LDS write width: %d, %d, A/B buffer load width: "
"%d/ %d\n",
A_Buffer_Load_Inst_Num,
@@ -133,6 +136,7 @@ struct BlockwiseGemmXdlops_pipeline_hotloop_inst
A_LDS_Read_Inst_Num,
B_LDS_Read_Inst_Num,
C_MFMA_Inst_Num,
C_MFMA_Inst_Cycle,
A_LDS_Read_Width,
B_LDS_Read_Width,
ALDSWriteWidth,

View File

@@ -43,8 +43,8 @@ struct f4x2_pk_t
using type = uint8_t;
type data;
__host__ __device__ f4x2_pk_t() : data{type{}} {}
__host__ __device__ f4x2_pk_t(type init) : data{init} {}
__host__ __device__ constexpr f4x2_pk_t() : data{type{}} {}
__host__ __device__ constexpr f4x2_pk_t(const type init) : data{init} {}
template <index_t I>
__host__ __device__ inline type unpack(Number<I>) const
@@ -165,6 +165,17 @@ inline constexpr bool is_native_type()
is_same<T, f8_fnuz_t>::value || is_same<T, bf8_fnuz_t>::value || is_same<T, bool>::value;
}
template <typename T>
struct is_f8f6f4
{
static constexpr bool value =
is_same_v<T, f8_t> || is_same_v<T, bf8_t> || is_same_v<T, f6_t> || is_same_v<T, bf6_t> ||
is_same_v<T, f6x16_pk_t> || is_same_v<T, f6x32_pk_t> || is_same_v<T, bf6x16_pk_t> ||
is_same_v<T, bf6x32_pk_t> || is_same_v<T, f4_t> || is_same_v<T, f4x2_pk_t>;
};
template <typename T>
inline constexpr bool is_f8f6f4_v = is_f8f6f4<T>::value;
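// Usage sketch: is_f8f6f4_v<f4x2_pk_t> and is_f8f6f4_v<bf6_t> are true, while
// is_f8f6f4_v<half_t> is false; the trait is meant to gate code paths that are
// only valid for the f8f6f4 MFMA operand family.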
// scalar_type
template <typename TV>
struct scalar_type;
@@ -303,105 +314,87 @@ struct scalar_type<bool>
static constexpr index_t vector_size = 1;
};
// Default behavior for types that do not need special handling
template <typename T>
struct packed_type
{
using type = T;
static constexpr index_t packed_size = 1; // number of packed elements
};
template <>
struct packed_type<int4_t>
{
using type = pk_i4_t;
static constexpr index_t packed_size = 2; // number of packed elements
};
template <>
struct packed_type<f4_t>
{
using type = f4x2_pk_t;
static constexpr index_t packed_size = 2; // number of packed elements
};
template <>
struct packed_type<f6_t>
{
using type = f6x32_pk_t;
static constexpr index_t packed_size = f6x32_pk_t::packed_size; // number of packed elements
};
template <>
struct packed_type<bf6_t>
{
using type = bf6x32_pk_t;
static constexpr index_t packed_size = bf6x32_pk_t::packed_size; // number of packed elements
};
template <typename T>
using packed_type_t = typename packed_type<T>::type;
// Check if the type has packed type specialization
template <typename T>
inline constexpr bool has_packed_type_v = !is_same_v<packed_type_t<T>, T>;
template <typename T>
struct element_type
struct packed_type_info
{
private:
static constexpr auto get_element_type()
static constexpr auto get_packed_type_info()
{
using U = remove_cvref_t<T>;
if constexpr(is_same_v<U, pk_i4_t>)
return int4_t{};
return ck::Tuple<ck::Number<2>, int4_t>{};
else if constexpr(is_same_v<U, f4x2_pk_t>)
return f4_t{};
return ck::Tuple<ck::Number<2>, f4_t>{};
else if constexpr(is_same_v<U, f6x16_pk_t>)
return f6_t{};
return ck::Tuple<ck::Number<16>, f6_t>{};
else if constexpr(is_same_v<U, bf6x16_pk_t>)
return bf6_t{};
return ck::Tuple<ck::Number<16>, bf6_t>{};
else if constexpr(is_same_v<U, f6x32_pk_t>)
return f6_t{};
return ck::Tuple<ck::Number<32>, f6_t>{};
else if constexpr(is_same_v<U, bf6x32_pk_t>)
return bf6_t{};
return ck::Tuple<ck::Number<32>, bf6_t>{};
else
return ck::Tuple<ck::Number<1>, T>{};
}
public:
using element_type = remove_cvref_t<decltype(get_packed_type_info().At(ck::Number<1>{}))>;
static constexpr auto packed_size =
static_cast<index_t>(get_packed_type_info().At(ck::Number<0>{}));
};
template <typename T>
using element_type_t = typename packed_type_info<T>::element_type;
template <typename T>
inline constexpr index_t packed_size_v = packed_type_info<T>::packed_size;
template <typename T>
inline constexpr bool is_packed_type_v = packed_size_v<T> > 1;
template <typename T, index_t N = 0>
struct packed_type_maker
{
private:
static constexpr auto get_packed_type()
{
using U = remove_cvref_t<T>;
if constexpr(is_same_v<U, int4_t>)
{
static_assert(N == 0 || N == 2, "Packed size N for int4_t must be 2.");
return pk_i4_t{};
}
else if constexpr(is_same_v<U, f4_t>)
{
static_assert(N == 0 || N == 2, "Packed size N for f4_t must be 2.");
return f4x2_pk_t{};
}
else if constexpr(is_same_v<U, f6_t>)
{
static_assert(N == 0 || N == 16 || N == 32, "Packed size N for f6_t must be 16 or 32.");
if constexpr(N == 16)
return f6x16_pk_t{};
else if constexpr(N == 0 || N == 32)
return f6x32_pk_t{};
}
else if constexpr(is_same_v<U, bf6_t>)
{
static_assert(N == 0 || N == 16 || N == 32,
"Packed size N for bf6_t must be 16 or 32.");
if constexpr(N == 16)
return bf6x16_pk_t{};
else if constexpr(N == 0 || N == 32)
return bf6x32_pk_t{};
}
else
return T{};
}
public:
using type = decltype(get_element_type());
};
template <typename T>
using element_type_t = typename element_type<T>::type;
template <typename T>
inline constexpr bool is_packed_type_v =
has_packed_type_v<element_type_t<T>>&& is_same_v<T, packed_type_t<element_type_t<T>>>;
template <typename T>
struct packed_size
{
private:
static constexpr auto get_packed_size()
{
using U = remove_cvref_t<T>;
if constexpr(is_packed_type_v<U>)
return Number<packed_type<element_type_t<U>>::packed_size>{};
else
return Number<packed_type<U>::packed_size>{};
}
public:
using type = decltype(get_packed_size());
static constexpr auto value = get_packed_size();
using packed_type = remove_cvref_t<decltype(get_packed_type())>;
};
template <typename T>
using packed_size_t = typename packed_size<T>::type;
template <typename T>
inline constexpr index_t packed_size_v = packed_size<T>::value;
template <typename T, index_t N = 0>
using packed_type_t = typename packed_type_maker<T, N>::packed_type;
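// Compile-time sanity checks illustrating the traits above (a sketch; the
// values follow directly from the specializations in this file):
//   static_assert(packed_size_v<f4x2_pk_t> == 2);
//   static_assert(is_same_v<element_type_t<f4x2_pk_t>, f4_t>);
//   static_assert(is_same_v<packed_type_t<f4_t>, f4x2_pk_t>);
//   static_assert(is_same_v<packed_type_t<f6_t, 16>, f6x16_pk_t>);
//   static_assert(packed_size_v<float> == 1 && !is_packed_type_v<float>);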
#if defined(_WIN32)
using int64_t = long long;

View File

@@ -1330,6 +1330,12 @@ struct nnvb_data_t_selector<pk_i4_t>
using type = pk_i4_t::type;
};
template <>
struct nnvb_data_t_selector<f4x2_pk_t>
{
using type = f4x2_pk_t::type;
};
template <typename T, index_t N>
struct non_native_vector_base<
T,
@@ -2222,6 +2228,7 @@ using f6x32_t = typename vector_type<f6x32_pk_t, 1>::type;
using bf6x16_t = typename vector_type<bf6x16_pk_t, 1>::type;
using bf6x32_t = typename vector_type<bf6x32_pk_t, 1>::type;
using e8m0x4_bexp_t = typename vector_type<e8m0_bexp_t, 4>::type;
// packed int4
using pk_i4x2_t = typename vector_type<pk_i4_t, 2>::type;
using pk_i4x4_t = typename vector_type<pk_i4_t, 4>::type;

View File

@@ -1,10 +1,11 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/functional.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/utility/tuple.hpp"
namespace ck {
@@ -70,4 +71,44 @@ struct static_for<0, N, 1> : detail::make_applier<N>
using detail::make_applier<N>::operator();
};
template <typename... Is>
struct static_for_range
{
template <typename F>
__host__ __device__ constexpr void operator()(F f) const
{
// tweak -fbracket-depth if compilation fails. Clang default limit is 256
(f(Is{}), ...);
}
};
template <typename... Ts>
struct static_for_product;
template <typename... Is>
struct static_for_product<Tuple<Is...>> : public static_for_range<Is...>
{
};
template <typename... Is, typename... Rest>
struct static_for_product<Tuple<Is...>, Rest...>
{
template <typename F>
__host__ __device__ constexpr void operator()(F f) const
{
static_for_product<Tuple<Is...>>{}([&](auto i0) { //
static_for_product<Rest...>{}([&](auto... is) { //
f(i0, is...);
});
});
}
};
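// Usage sketch: static_for_product visits the Cartesian product of its
// compile-time ranges, e.g.
//   static_for_product<Tuple<Number<0>, Number<1>>,
//                      Tuple<Number<0>, Number<2>>>{}([&](auto i, auto j) {
//       // visits (0, 0), (0, 2), (1, 0), (1, 2) in this order
//   });
// static_for_range<Is...> is the one-dimensional base case.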
struct identity
{
template <typename T>
__host__ __device__ constexpr T&& operator()(T&& arg) const noexcept
{
return forward<T>(arg);
}
};
} // namespace ck

View File

@@ -5,14 +5,22 @@
namespace ck {
template <auto v>
struct constant
{
using value_type = decltype(v);
using type = constant; // using injected-class-name
static constexpr value_type value = v;
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
template <class T, T v>
struct integral_constant
struct integral_constant : constant<v>
{
static constexpr T value = v;
typedef T value_type;
typedef integral_constant type;
__host__ __device__ constexpr operator value_type() const noexcept { return value; }
__host__ __device__ constexpr value_type operator()() const noexcept { return value; }
};
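// Note on the refactor above: integral_constant<T, v> now derives from
// constant<v> but keeps its own value/value_type/type members, so existing
// users (e.g. the ck::Number<N> alias) are unaffected, while new code can use
// the deduced non-type-template-parameter form directly, e.g.
//   constexpr ck::constant<4> four{}; // value_type deduced as int
//   static_assert(four() == 4 && four == 4);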
template <typename TX, TX X, typename TY, TY Y>

View File

@@ -1586,6 +1586,11 @@ inline __host__ __device__ f4x2_t type_convert<f4x2_t, float2_t>(float2_t x)
return f4_convert_rne(x);
#endif
}
template <>
inline __host__ __device__ f4x2_pk_t type_convert<f4x2_pk_t, float2_t>(float2_t x)
{
return static_cast<f4x2_pk_t>(type_convert<f4x2_t>(x));
}
// convert vector of 32 fp32 to vector of 32 fp4
template <>

View File

@@ -112,7 +112,7 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc;
return a_lds_block_desc;
#endif
}

View File

@@ -77,33 +77,34 @@ struct ReferenceMXGemm : public device::BaseOperator
ComputeTypeA,
ComputeTypeB>;
Tensor<ComputeTypeA> a_m_k_scaled(arg.a_m_k_.mDesc);
Tensor<ComputeTypeB> b_k_n_scaled(arg.b_k_n_.mDesc);
const ck::index_t M = arg.a_m_k_.mDesc.GetLengths()[0];
const ck::index_t N = arg.b_k_n_.mDesc.GetLengths()[1];
assert(arg.a_m_k_.mDesc.GetLengths()[1] == arg.b_k_n_.mDesc.GetLengths()[0]);
const ck::index_t K = arg.a_m_k_.mDesc.GetLengths()[1];
const ck::index_t SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
Tensor<ComputeTypeA> a_m_k_scaled(HostTensorDescriptor({M, K}, {K, 1}));
Tensor<ComputeTypeB> b_k_n_scaled(HostTensorDescriptor({K, N}, {1, K}));
// printf("K: %d\n", K);
const auto M = arg.a_m_k_.mDesc.GetLengths()[0];
const auto N = arg.b_k_n_.mDesc.GetLengths()[1];
const auto K = arg.a_m_k_.mDesc.GetLengths()[1];
const auto SCALE_BLOCK = K / arg.a_m_kblock_scales_.mDesc.GetLengths()[1];
for(size_t m = 0; m < M; m++)
for(int m = 0; m < M; m++)
{
for(size_t k = 0; k < K; k++)
for(int k = 0; k < K; k++)
{
if constexpr(is_same_v<ADataType, f4x2_pk_t>)
{
// TODO: add support for ColMajor layout as well
if(k % 2 == 1)
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
else
a_m_k_scaled(m, k) =
type_convert<ComputeTypeA>(
f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeA>(
arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
{
continue;
}
// TODO: add support for ColMajor layout as well
auto a_pack = arg.a_m_k_(m, k);
auto a_scale =
type_convert<ComputeTypeA>(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK));
auto a_f4_lo = f4_t(a_pack.template unpack<>(Number<0>{}));
auto a_f4_hi = f4_t(a_pack.template unpack<>(Number<1>{}));
a_m_k_scaled(m, k) = type_convert<ComputeTypeA>(a_f4_lo) * a_scale;
a_m_k_scaled(m, k + 1) = type_convert<ComputeTypeA>(a_f4_hi) * a_scale;
}
else if constexpr(is_same_v<ADataType, f6x16_pk_t> ||
is_same_v<ADataType, bf6x16_pk_t> ||
@@ -124,25 +125,24 @@ struct ReferenceMXGemm : public device::BaseOperator
}
}
for(size_t n = 0; n < N; n++)
for(int n = 0; n < N; n++)
{
for(size_t k = 0; k < K; k++)
for(int k = 0; k < K; k++)
{
if constexpr(is_same_v<BDataType, f4x2_pk_t>)
{
// TODO: add support for RowMajor layout as well
if(k % 2 == 1)
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
else
b_k_n_scaled(k, n) =
type_convert<ComputeTypeB>(
f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) *
type_convert<ComputeTypeB>(
arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
{
continue;
}
auto b_pack = arg.b_k_n_(k, n);
auto b_scale =
type_convert<ComputeTypeB>(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n));
auto b_f4_lo = f4_t(b_pack.template unpack<>(Number<0>{}));
auto b_f4_hi = f4_t(b_pack.template unpack<>(Number<1>{}));
b_k_n_scaled(k, n) = type_convert<ComputeTypeB>(b_f4_lo) * b_scale;
b_k_n_scaled(k + 1, n) = type_convert<ComputeTypeB>(b_f4_hi) * b_scale;
}
else if constexpr(is_same_v<BDataType, f6x16_pk_t> ||
is_same_v<BDataType, bf6x16_pk_t> ||

View File

@@ -23,6 +23,10 @@ using I32 = int32_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using I4 = ck::pk_i4_t;
using F4 = ck::f4x2_pk_t;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Empty_Tuple = ck::Tuple<>;
@@ -42,8 +46,9 @@ using BF16_Tuple = ck::Tuple<BF16>;
using F32_F32_Tuple = ck::Tuple<F32, F32>;
// GEMM layout
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using MFMA = ck::tensor_layout::gemm::MFMA;
using Row_Tuple = ck::Tuple<Row>;
using Row_Row_Tuple = ck::Tuple<Row, Row>;

View File

@@ -22,9 +22,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(
Col,
Row,
F8,
e8m0_bexp_t,
E8M0PK,
F8,
e8m0_bexp_t,
E8M0PK,
F16,
32,
PassThrough,
@@ -36,23 +36,37 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
Col,
Row,
F8,
e8m0_bexp_t,
E8M0PK,
F8,
e8m0_bexp_t,
E8M0PK,
BF16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Col,
Row,
F4,
I32,
F4,
I32,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Row,
Row,
BF8,
e8m0_bexp_t,
E8M0PK,
F8,
e8m0_bexp_t,
E8M0PK,
F16,
32,
PassThrough,
@@ -64,9 +78,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
Col,
Row,
F8,
e8m0_bexp_t,
E8M0PK,
F8,
e8m0_bexp_t,
E8M0PK,
BF16,
32,
PassThrough,
@@ -94,7 +108,8 @@ struct DeviceOperationInstanceFactory<
ScaleBlockSize,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>>
ck::tensor_operation::element_wise::PassThrough>,
enable_if_t<!is_same_v<BLayout, MFMA>>> // non-weight-pre-shuffle
{
using DeviceOp = DeviceGemmMX<ALayout,
BLayout,
@@ -127,6 +142,11 @@ struct DeviceOperationInstanceFactory<
add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(op_ptrs);
}
else if constexpr(is_same_v<ADataType, F4> && is_same_v<BDataType, F4> &&
is_same_v<CDataType, F16>)
{
add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(op_ptrs);
}
}
else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
is_same_v<CLayout, Row>)
@@ -153,6 +173,73 @@ struct DeviceOperationInstanceFactory<
}
};
void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
MFMA,
Row,
F4,
I32,
F4,
I32,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances);
template <typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
index_t ScaleBlockSize,
typename ALayout,
typename BLayout,
typename CLayout>
struct DeviceOperationInstanceFactory<
ck::tensor_operation::device::DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
CDataType,
ScaleBlockSize,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>,
enable_if_t<is_same_v<BLayout, MFMA>>>
{
using DeviceOp = DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
CDataType,
ScaleBlockSize,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough>;
static auto GetInstances()
{
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, MFMA> && is_same_v<CLayout, Row>)
{
if constexpr(is_same_v<ADataType, F4> && is_same_v<BDataType, F4> &&
is_same_v<CDataType, F16>)
{
add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(op_ptrs);
}
}
return op_ptrs;
}
};
} // namespace instance
} // namespace device
} // namespace tensor_operation

View File

@@ -34,19 +34,19 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f16_f16_f16_mk_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 64, 16, 16, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 64, 16, 16, 16, 16, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, S<2, 16, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 32, 16, 64, 16, 16, 16, 16, 1, 1, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, S<2, 8, 8>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 128, 32, 32, 16, 16, 1, 1, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, S<1, 4, 16>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 64, 8, 8, 32, 32, 1, 1, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 64, 32, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 16, 32, 32, 8, 8, 16, 16, 1, 1, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, S<4, 8, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 128, 32, 16, 64, 8, 8, 16, 16, 1, 1, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, S<8, 4, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F16, F16, F16, F32, F32, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 128, 8, 8, 16, 16, 1, 1, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, S<16, 1, 4>, S<1, 0, 2>, 2, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>
// clang-format on
>;

View File

@@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_kn_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_km_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Col, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -31,8 +31,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_kn_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<0, 2, 1>, 1, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Row, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 32, 32, 32, 4, 4, 32, 32, 1, 1, S<8, 2, 4>, S<1, 0, 2>, 2, 1, 0, S<1, 16, 4>, S<0, 1, 2>, 1, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -32,8 +32,8 @@ using device_gemm_xdl_c_shuffle_lds_direct_load_f32_f32_f32_mk_nk_mn_instances =
// ##################################| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ##################################| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ##################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, 1, 1, S<1, 8, 1, 8>, 4>
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemm_Xdl_CShuffle_LdsDirectLoad< Row, Col, Row, F32, F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 256, 64, 64, 32, 8, 8, 32, 32, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 0, 1, 1, S<1, 8, 1, 8>, 4>
// clang-format on
>;

View File

@@ -6,6 +6,8 @@ list(APPEND GEMM_MX_INSTANCES
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp
device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp
device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp
device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp
device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp
)
@@ -13,6 +15,8 @@ set_source_files_properties(device_gemm_mx_xdl_f8_f8_f16/device_gemm_mx_xdl_f8_f
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_f8_f8_bf16/device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_bf8_f8_f16/device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
set_source_files_properties(device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1")
add_instance_library(device_gemm_mx_instance ${GEMM_MX_INSTANCES})

View File

@@ -13,12 +13,13 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using BF8 = bf8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using F8 = f8_t;
using BF8 = bf8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
@@ -40,17 +41,19 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_instances = std::tuple<
#if 0 // TODO: Fix RRR
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
//#########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 64, 16, 128, 16, 4, 16, 16, 2, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 256, 16, 4, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 64, 256, 16, 4, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 4, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Row, Row, BF8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 16, 32, 512, 16, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<64, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
#endif
>;
} // namespace instance
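The E8M0 -> E8M0PK change above swaps the scalar scale type for a packed one (using E8M0PK = int32_t;). A minimal sketch of the assumed packing, not CK's implementation: e8m0 is an 8-bit biased exponent in the OCP MX spec, so one int32_t can carry four scales.

#include <cmath>
#include <cstdint>

// Hypothetical helpers for illustration only.
inline int32_t pack_e8m0x4(uint8_t e0, uint8_t e1, uint8_t e2, uint8_t e3)
{
    return int32_t(uint32_t(e0) | (uint32_t(e1) << 8) |
                   (uint32_t(e2) << 16) | (uint32_t(e3) << 24));
}

inline float e8m0_to_float(uint8_t e)
{
    // 0xFF encodes NaN in the OCP MX spec; everything else is 2^(e - 127).
    return (e == 0xFF) ? NAN : std::ldexp(1.0f, int(e) - 127);
}

inline float unpack_e8m0(int32_t packed, int lane) // lane in [0, 4)
{
    return e8m0_to_float(uint8_t((uint32_t(packed) >> (8 * lane)) & 0xFF));
}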

View File

@@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_bf8_f8_f16_mk_kn_mn_default_instances(
Row,
Row,
BF8,
E8M0,
E8M0PK,
F8,
E8M0,
E8M0PK,
F16,
32,
PassThrough,

View File

@@ -0,0 +1,73 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F4 = f4x2_pk_t;
using F16 = half_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
using MFMA = tensor_layout::gemm::MFMA;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances = std::tuple<
// clang-format off
//#####################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#####################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#####################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
// DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
// DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
// DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 384, 128, 16, 16, 16, 16, 2, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
// DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 512, 128, 16, 16, 16, 16, 2, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 384, 128, 16, 16, 16, 16, 4, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 512, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 384, 128, 16, 16, 16, 16, 6, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 512, 128, 16, 16, 16, 16, 6, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 8, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 384, 128, 16, 16, 16, 16, 8, 6, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3<Row, MFMA, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 512, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
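For reference on the F4 = f4x2_pk_t alias used in this file: MXFP4 elements are OCP e2m1 values (1 sign, 2 exponent, 1 mantissa bit) stored two per byte. A hedged decode sketch, assuming low-nibble-first packing:

#include <cstdint>

// Illustration only, not CK's implementation.
inline float e2m1_to_float(uint8_t nibble)
{
    // Magnitudes of the eight non-negative e2m1 codes.
    static constexpr float magnitude[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};
    const float m = magnitude[nibble & 0x7];
    return (nibble & 0x8) ? -m : m;
}

inline float f4x2_element(uint8_t packed, int lane) // lane in {0, 1}
{
    return e2m1_to_float(uint8_t((packed >> (4 * lane)) & 0xF));
}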

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
MFMA,
Row,
F4,
E8M0PK,
F4,
E8M0PK,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances<Intrawave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
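A hedged usage sketch for the factory above; the MFMA B-layout presumably corresponds to the preshuffled variant this commit adds. Argument setup and support checks are omitted, and GetTypeString() is assumed from the usual CK BaseOperator interface:

#include <iostream>
#include <memory>
#include <vector>

void list_f4_mfma_instances()
{
    using namespace ck::tensor_operation::device;
    using namespace ck::tensor_operation::device::instance;
    // Same DeviceGemmMX signature as in the factory's parameter list above.
    std::vector<std::unique_ptr<DeviceGemmMX<Row, MFMA, Row,
                                             F4, E8M0PK, F4, E8M0PK, F16, 32,
                                             PassThrough, PassThrough, PassThrough>>> ops;
    add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(ops);
    for(const auto& op : ops)
        std::cout << op->GetTypeString() << '\n'; // assumed accessor
}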

View File

@@ -0,0 +1,65 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
using F4 = f4x2_pk_t;
using F16 = half_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
template <index_t... Is>
using S = Sequence<Is...>;
using PassThrough = element_wise::PassThrough;
static constexpr auto GemmDefault = GemmSpecialization::Default;
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#############################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#############################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#############################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
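Tying ScaleBlockSize = 32 to these types: every 32 consecutive K-elements of an operand share one e8m0 scale, so the reference math for a single scale block is a scaled dot product. A sketch reusing the hypothetical f4x2_element helper from the decode note above:

// Reference math for one MX scale block; not the kernel's code path.
inline float mx_block_dot(const uint8_t* a_f4x2, float a_scale,
                          const uint8_t* b_f4x2, float b_scale)
{
    float sum = 0.0f;
    for(int k = 0; k < 32; ++k) // ScaleBlockSize elements, two per byte
        sum += f4x2_element(a_f4x2[k / 2], k % 2) *
               f4x2_element(b_f4x2[k / 2], k % 2);
    return a_scale * b_scale * sum;
}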

View File

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {
void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances(
std::vector<std::unique_ptr<DeviceGemmMX<Row,
Col,
Row,
F4,
E8M0PK,
F4,
E8M0PK,
F16,
32,
PassThrough,
PassThrough,
PassThrough>>>& instances)
{
add_device_operation_instances(
instances, device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances<Intrawave, GemmDefault>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -13,11 +13,12 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
@@ -39,19 +40,21 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances = std::tuple<
#if 0 // TODO: Fix CCR
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
//#########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 4, 16, 32, 32, 2, 2, S<32, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 16, 256, 128, 4, 16, 16, 16, 1, 4, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 8, 16, 16, 16, 1, 1, S<64, 1, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 8, 16, 16, 16, 8, 8, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 64, 4, 16, 32, 32, 4, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Col, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 128, 128, 4, 16, 16, 16, 4, 8, S<32, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 4, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
// clang-format on
#endif
>;
} // namespace instance
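One observable effect of the "#if 0 // TODO: Fix CCR" guard above: the entire list is compiled out, so the alias degenerates to an empty std::tuple and registration for this layout is a no-op until the TODO is resolved. A hedged check, assuming the usual aliases from this file are in scope:

#include <tuple>

static_assert(std::tuple_size<device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_instances<
                  Intrawave, GemmDefault>>::value == 0,
              "CCR instance list is disabled, so nothing gets registered");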

View File

@@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_km_nk_mn_default_instances(
Col,
Row,
F8,
E8M0,
E8M0PK,
F8,
E8M0,
E8M0PK,
BF16,
32,
PassThrough,

View File

@@ -13,11 +13,12 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
@@ -40,15 +41,15 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
//###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;

View File

@@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_bf16_mk_nk_mn_default_instances(
Col,
Row,
F8,
E8M0,
E8M0PK,
F8,
E8M0,
E8M0PK,
BF16,
32,
PassThrough,

View File

@@ -13,11 +13,12 @@ namespace tensor_operation {
namespace device {
namespace instance {
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using F8 = f8_t;
using F16 = half_t;
using BF16 = bhalf_t;
using F32 = float;
using E8M0 = ck::e8m0_bexp_t;
using E8M0PK = int32_t;
using Row = tensor_layout::gemm::RowMajor;
using Col = tensor_layout::gemm::ColumnMajor;
@@ -40,15 +41,15 @@ static constexpr auto ScaleBlockSize = 32;
template <BlockGemmPipelineScheduler BlkGemmPipeSched, GemmSpecialization GemmSpec>
using device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_instances = std::tuple<
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData|AScale|BData|BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 16, 128, 16, 16, 16, 16, 4, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, 2, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0, F8, E8M0, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 16, 16, 512, 16, 16, 16, 16, 1, 1, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 4>, 4, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
//###########################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//###########################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//###########################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//###########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 256, 16, 16, 16, 16, 4, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 64, 256, 16, 16, 16, 16, 4, 2, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 256, 16, 16, 16, 16, 2, 4, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16,16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 128, 128, 32, 256, 16, 16, 16, 16, 4, 2, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>,
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F8, E8M0PK, F8, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 256, 16, 16, 16, 16, 2, 2, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<16, 4, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>
// clang-format on
>;
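A small arithmetic sketch connecting E8M0PK with ScaleBlockSize for the rows above: one e8m0 scale per 32 K-elements and four scales packed per int32_t, so a KPerBlock of 512 needs 16 scales, i.e. 4 packed words per row (assuming K divisible by 128):

// Hypothetical helper, not CK's layout code.
constexpr int scale_words_per_row(int K)
{
    return (K / 32) / 4; // ScaleBlockSize = 32, four e8m0 scales per int32_t
}
static_assert(scale_words_per_row(512) == 4, "KPerBlock 512 -> 16 scales -> 4 words");
static_assert(scale_words_per_row(256) == 2, "KPerBlock 256 -> 8 scales -> 2 words");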

View File

@@ -13,9 +13,9 @@ void add_device_gemm_mx_xdl_f8_f8_f16_mk_nk_mn_default_instances(
Col,
Row,
F8,
E8M0,
E8M0PK,
F8,
E8M0,
E8M0PK,
F16,
32,
PassThrough,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
@@ -37,30 +37,30 @@ using device_gemm_xdl_splitk_lds_direct_load_f16_f16_f16_mk_nk_mn_instances = st
//#######################################| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Specialization| Prefetch| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector|
//#######################################| | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | Wave| Wave| Lengths_KBatch_K0_M_K1| | | PerVector| | Lengths_KBatch_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl|
//#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 16, 128, 4, 16, 16, 16, 1, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 4, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 16, 128, 8, 8, 16, 16, 1, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 64, 16, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 2, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 16, 64, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 64, 4, 32, 16, 16, 1, 2, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 8, 16, 16, 16, 1, 1, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 128, 128, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 32, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 256, 16, 64, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 64, 16, 8, 16, 16, 1, 2, S<1, 16, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 2, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 128, 16, 32, 4, 8, 16, 16, 1, 1, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 8, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 1, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 64, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 256, 16, 64, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNPadding, 2, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 128, 4, 32, 16, 16, 1, 2, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 8, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 2, 16, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 4, 32, 16, 16, 1, 1, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 1, 4, 16>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 64, 16, 4, 16, 16, 16, 1, 1, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 4, 8, 8>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 128, 16, 8, 16, 16, 1, 2, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 16>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 16, 8, 16, 16, 1, 1, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 32, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 16, 64, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 16, 32, 8, 8, 16, 16, 1, 1, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 4, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 8>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 16, 16, 16, 8, 16, 16, 1, 1, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 16, 1, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 16, 1, 4>, 4>,
DeviceGemmXdlSplitKCShuffle_LdsDirectLoad< F16, F16, F16, F32, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 2, 256, 64, 16, 8, 8, 16, 16, 1, 1, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, S<1, 8, 8, 4>, S<0, 2, 1, 3>, 3, 2, 0, 1, 1, S<1, 32, 1, 4>, 4>
// clang-format on
>;

View File

@@ -0,0 +1,534 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <iostream>
#include <typeinfo>
#include "ck/ck.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/utility/data_type.hpp"
namespace ck {
namespace profiler {
#if 1
template <bool KLast>
void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K)
{
int MNXdlPack = 2;
int KXdlPack = 2;
int XdlMNThread = 16;
int XdlKThread = 64 / XdlMNThread;
int K0 = K / KXdlPack / XdlKThread; // KRepeat
// The four 16x128 building blocks are packed into one 32x256 block for F4
// (the eight 16x16x128 MFMAs are packed into one 32x32x256 MFMA for F4).
// Unfold the 32(MN) x 8(K = 256/32) scale buffer in the order
//      4              16            2            2
// XdlKThread -> XdlMNThread -> KXdlPack -> MNXdlPack,
// then MNRepeat -> KRepeat above that.
for(int n = 0; n < MN; ++n)
{
for(int k = 0; k < K; ++k)
{
int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat
int tempn = n % (XdlMNThread * MNXdlPack);
int n1 = tempn % XdlMNThread; // i XdlMNThread
int n2 = tempn / XdlMNThread; // i MNXdlPack
int k0 = k / (XdlKThread * KXdlPack); // i KRepeat
int tempk = k % (XdlKThread * KXdlPack);
int k1 = tempk % XdlKThread; // i XdlKThread
int k2 = tempk / XdlKThread; // i KXdlPack
int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
k2 * MNXdlPack + n2;
// src[n * K + k] = ck::type_convert<ck::e8m0_bexp_t>(static_cast<float>(powf(2.0f, n2 +
// k2 * MNXdlPack)));
if constexpr(KLast)
dst[outputIndex] = src[n * K + k];
else
dst[outputIndex] = src[k * MN + n];
}
}
}
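
To make the mapping above concrete, here is a minimal standalone sketch (not part of the commit) that reproduces the shuffled index for one (n, k) pair with the function's hard-coded constants:

#include <cstdio>

int main()
{
    const int MNXdlPack = 2, KXdlPack = 2, XdlMNThread = 16;
    const int XdlKThread = 64 / XdlMNThread;  // 4
    const int K = 8;                          // one 32(MN) x 8(K) scale tile
    const int K0 = K / KXdlPack / XdlKThread; // 1 K-repeat

    const int n = 5, k = 3;
    int n0 = n / (XdlMNThread * MNXdlPack), tempn = n % (XdlMNThread * MNXdlPack);
    int n1 = tempn % XdlMNThread, n2 = tempn / XdlMNThread;   // 5, 0
    int k0 = k / (XdlKThread * KXdlPack), tempk = k % (XdlKThread * KXdlPack);
    int k1 = tempk % XdlKThread, k2 = tempk / XdlKThread;     // 3, 0

    int idx = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 +
              k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread +
              k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack +
              k2 * MNXdlPack + n2;
    std::printf("(n=%d, k=%d) -> %d\n", n, k, idx);           // prints 212
}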
void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl)
{
int KPack = 16;
int NLane = NXdl;
int KLane = 64 / NLane;
int K_pk = K / 2;
int K0 = K_pk / (KLane * KPack);
// K -> K0 KLane KPack
// N -> N0 NLane
// N, K -> N0 K0 KLane NLane KPack
for(int n = 0; n < N; ++n)
{
for(int k = 0; k < K_pk; ++k)
{
int n0 = n / NLane;
int n1 = n % NLane;
int k0 = k / (KLane * KPack);
int tempk = k % (KLane * KPack);
int k1 = tempk / KPack;
int k2 = tempk % KPack;
int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane +
k1 * KPack * NLane + n1 * KPack + k2;
dst[outputIndex] = src[n * K_pk + k];
}
}
}
#endif
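
preShuffleBuffer works on f4x2_pk_t, i.e. two FP4 values per byte, so it walks K_pk = K / 2 packed elements; with NXdl = 16 one shuffled tile is NLane * KLane * KPack = 16 * 4 * 16 = 1024 packed bytes, covering 16 rows of N by 64 packed (128 unpacked) K-elements. A small sketch of those extents, assuming the same hard-coded constants:

#include <cstdio>

int main()
{
    const int K = 512;                          // FP4 elements along K
    const int K_pk = K / 2;                     // f4x2_pk_t: two FP4 per byte
    const int NXdl = 16, KPack = 16;
    const int NLane = NXdl, KLane = 64 / NLane; // 16 N-lanes, 4 K-lanes
    const int K0 = K_pk / (KLane * KPack);      // shuffled K-tiles per column
    std::printf("K_pk = %d, K0 = %d, tile = %d bytes\n",
                K_pk, K0, NLane * KLane * KPack); // 256, 4, 1024
}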
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
int ScaleBlockSize>
bool profile_gemm_mx_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
{
using tensor_operation::device::instance::Col;
using tensor_operation::device::instance::E8M0;
using tensor_operation::device::instance::E8M0PK;
using tensor_operation::device::instance::MFMA;
using tensor_operation::device::instance::Row;
constexpr bool BPreShuffle = is_same_v<BLayout, MFMA>;
using BRefLayout = conditional_t<BPreShuffle, Col, BLayout>;
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be a multiple of ScaleBlockSize.");
}
using XDataType = E8M0;
using XPackedDataType = E8M0PK;
using AScaleLayout = Row;
using BScaleLayout = Col;
auto f_host_tensor_descriptor =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
return HostTensorDescriptor({row, col}, {stride, 1});
else
return HostTensorDescriptor({row, col}, {1, stride});
};
auto f_get_default_stride =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// if stride is -1, fall back to a packed (contiguous) default stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
return static_cast<ck::index_t>(col);
else
return static_cast<ck::index_t>(row);
}
else
return static_cast<ck::index_t>(stride);
};
auto Scale_Padded_M = (M + 32 - 1) / 32 * 32; // round M up to a multiple of 32 for the shuffled scale layout
auto Scale_Stride_AM =
f_get_default_stride(Scale_Padded_M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
auto b_k_n =
std::make_shared<Tensor<BDataType>>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{}));
auto b_input = b_k_n;
if constexpr(BPreShuffle)
b_input = std::make_shared<Tensor<BDataType>>(
f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size
// scales for A and B
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
Tensor<XDataType> b_k_n_scale(
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
// shuffled scales for A and B
Tensor<XDataType> a_shuffled_scale(f_host_tensor_descriptor(
Scale_Padded_M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{}));
Tensor<XDataType> b_shuffled_scale(
f_host_tensor_descriptor(K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{}));
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::size_t total_gemm_needed =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() +
a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() +
a_shuffled_scale.GetElementSpaceSizeInBytes() +
b_shuffled_scale.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
1,
std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
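// rotating_count: how many copies of all input/output buffers fit in the
// requested rotating-buffer budget, clamped to [1, n_iter]; a value > 1 is
// forwarded to StreamConfig below, presumably so timed iterations rotate
// through distinct buffers instead of re-hitting a warm cache.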
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n->mDesc << std::endl;
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
auto a_data_element = [](float x) {
if constexpr(ck::is_same_v<ADataType, ck::f4x2_pk_t>)
return ck::type_convert<ADataType>(ck::float2_t(x));
else
return ck::type_convert<ADataType>(x);
};
auto b_data_element = [](float x) {
if constexpr(ck::is_same_v<BDataType, ck::f4x2_pk_t>)
return ck::type_convert<BDataType>(ck::float2_t(x));
else
return ck::type_convert<BDataType>(x);
};
switch(init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{a_data_element(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{b_data_element(0.5f)}(*b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
if(do_log)
{
std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A scale = {2.0}" << std::endl;
std::cout << "Init B = {0.5}" << std::endl;
std::cout << "Init B scale = {1.0}" << std::endl;
std::cout << "Expect C = {K}" << std::endl;
}
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5}); // Z[-4,4]
b_k_n->GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5}); // Z[-4,4]
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<XDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
b_k_n->GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{powf(2.0f, -125.0f), 1.0f});
break;
}
#if 1
preShuffleScaleBuffer<ck::is_same_v<ALayout, Row>>(a_m_k_scale.mData.data(),
a_shuffled_scale.mData.data(),
Scale_Padded_M,
K / ScaleBlockSize);
preShuffleScaleBuffer<ck::is_same_v<BRefLayout, Col>>(
b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize);
if constexpr(BPreShuffle)
{
int NPerXdl = 16; // Fixed 16
preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl);
}
#endif
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
if(do_log > 0)
std::cout << "Device memory allocation..." << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize());
if(do_log > 0)
std::cout << "Upload data to device..." << std::endl;
a_device_buf.ToDevice(a_m_k.mData.data());
a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data());
b_device_buf.ToDevice(b_input->mData.data());
b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data());
if(do_log > 0)
std::cout << "Done." << std::endl;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
XPackedDataType,
BDataType,
XPackedDataType,
CDataType,
ScaleBlockSize,
AElementOp,
BElementOp,
CElementOp>;
std::cout << "finding op instances..." << std::endl;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm< //
ADataType,
BDataType,
CDataType,
float, // AccDataType
XDataType,
AElementOp,
BElementOp,
CElementOp,
float, // ComputeTypeA
float // ComputeTypeB
>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k_scale,
*b_k_n,
b_k_n_scale,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
std::string best_op_name;
std::optional<std::string> best_op_object_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
bool pass = true;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; // use these when KBatch <= 0
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<XPackedDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideC,
kbatch_curr,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_log)
{
if(init_method == 0)
{
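// init_method 0 fills A = 1 (scale 2) and B = 0.5 (scale 1), so every scaled
// product is exactly 1.0f and the checked element should equal K bit-exactly.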
auto expected = static_cast<float>(K);
auto computed = type_convert<float>(c_m_n_device_result(0, 12));
pass = pass & (std::abs(expected - computed) <= 0.0f);
std::cout << "\nExpected vs Computed: " << expected << " vs "
<< computed << ((pass) ? " (PASSED!)" : " (FAILED!)")
<< std::endl
<< std::endl;
}
else
{
if constexpr(is_same_v<ADataType, ck::f8_t> ||
is_same_v<ADataType, ck::bf8_t>)
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",")
<< "\n";
else
std::cout << "A: WIP PRINT PACKED TYPE\n";
LogRangeAsType<float>(std::cout << "a_scale : ", a_m_k_scale.mData, ",")
<< "\n";
if constexpr(is_same_v<BDataType, ck::f8_t> ||
is_same_v<BDataType, ck::bf8_t>)
LogRangeAsType<float>(std::cout << "b : ", b_k_n->mData, ",")
<< "\n";
else
std::cout << "B: WIP PRINT PACKED TYPE\n";
LogRangeAsType<float>(std::cout << "b_scale: ", b_k_n_scale.mData, ",")
<< "\n";
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< "\n";
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
std::string op_name = op_ptr->GetTypeString();
std::optional<std::string> op_obj_name = op_ptr->GetObjectName();
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
// Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) +
// scaling of partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop =
std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
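// e.g. M = N = K = 4096 with ScaleBlockSize = 32:
// flop = 2*4096^3 + 2*4096^3/32 ~= 1.374e11 + 4.295e9 ~= 1.417e11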
// TODO: fp6?
std::size_t num_btype = sizeof(ADataType) * M * K / packed_size_v<ADataType> +
sizeof(BDataType) * K * N / packed_size_v<BDataType> +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
best_op_object_name = op_obj_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
if constexpr(is_same<CDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
std::cout << " ALayout = " << ALayout::name;
std::cout << " BLayout = " << BLayout::name;
std::cout << " CLayout = " << CLayout::name;
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
if(best_op_object_name)
std::cout << best_op_object_name.value() << std::endl;
return pass;
}
} // namespace profiler
} // namespace ck

View File

@@ -63,6 +63,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND PROFILER_OPS profile_gemm_multiply_multiply_wp.cpp)
list(APPEND PROFILER_OPS profile_gemm_ab_scale.cpp)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND PROFILER_OPS profile_gemm_mx.cpp)
endif()
list(APPEND PROFILER_OPS profile_batched_gemm.cpp)
list(APPEND PROFILER_OPS profile_batched_gemm_reduce.cpp)
list(APPEND PROFILER_OPS profile_gemm_add_multiply.cpp)
@@ -168,6 +171,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9")
list(APPEND DEVICE_INSTANCES device_gemm_multiply_multiply_wp_instance)
list(APPEND DEVICE_INSTANCES device_gemm_ab_scale_instance)
endif()
if(SUPPORTED_GPU_TARGETS MATCHES "gfx95")
list(APPEND DEVICE_INSTANCES device_gemm_mx_instance)
endif()
list(APPEND DEVICE_INSTANCES device_gemm_splitk_instance)
list(APPEND DEVICE_INSTANCES device_gemm_b_scale_instance)
list(APPEND DEVICE_INSTANCES device_batched_gemm_b_scale_instance)

View File

@@ -0,0 +1,155 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_gemm_mx_impl.hpp"
#include "profiler_operation_registry.hpp"
enum struct GemmMatrixLayout
{
MK_KN_MN, // 0
MK_NK_MN, // 1
MK_MFMA_MN, // 2
};
enum struct GemmDataType
{
F4_F4_F16, // 0
F8_F8_F16, // 1
F8_F8_BF16, // 2
};
#define OP_NAME "gemm_mx"
#define OP_DESC "GEMM_mx"
int profile_gemm_mx(int argc, char* argv[])
{
if(argc != 11 && argc != 14 && argc != 18)
{
printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n");
printf("arg2: data type (0: f4->f16 ;\n");
printf(" 1: fp8->f16 ;\n");
printf(" 2: fp8->bf16 )\n");
printf("arg3: matrix layout (0: A[m, k] * B[k, n] = C[m, n] ;\n");
printf(" 1: A[m, k] * B[n, k] = C[m, n] ;\n");
printf("                    2: A[m, k] * BPreShuff = C[m, n])\n");
printf("arg4: verification (0: no; 1: yes)\n");
printf("arg5: initialization (0: no init; 1: integer value; 2: decimal value)\n");
printf("arg6: print tensor value (0: no; 1: yes)\n");
printf("arg7: time kernel (0=no, 1=yes)\n");
printf("arg8 to 13: M, N, K, StrideA, StrideB, StrideC\n");
printf("optional:\n");
printf("arg14: number of kbatch (default 1)\n");
printf("arg15: number of warm-up cycles (default 1)\n");
printf("arg16: number of iterations (default 10)\n");
printf("arg17: memory for rotating buffer (default 0, size in MB)\n");
exit(1);
}
int arg_index = 2;
const auto data_type = static_cast<GemmDataType>(std::stoi(argv[arg_index++]));
const auto layout = static_cast<GemmMatrixLayout>(std::stoi(argv[arg_index++]));
const bool do_verification = std::stoi(argv[arg_index++]);
const int init_method = std::stoi(argv[arg_index++]);
const bool do_log = std::stoi(argv[arg_index++]);
const bool time_kernel = std::stoi(argv[arg_index++]);
const int M = std::stoi(argv[arg_index++]);
const int N = std::stoi(argv[arg_index++]);
const int K = std::stoi(argv[arg_index++]);
int StrideA = -1, StrideB = -1, StrideC = -1;
if(argc > arg_index)
{
StrideA = std::stoi(argv[arg_index++]);
StrideB = std::stoi(argv[arg_index++]);
StrideC = std::stoi(argv[arg_index++]);
}
int KBatch = 1;
int n_warmup = 1;
int n_iter = 10;
uint64_t rotating = 0;
if(argc > arg_index)
{
KBatch = std::stoi(argv[arg_index++]);
n_warmup = std::stoi(argv[arg_index++]);
n_iter = std::stoi(argv[arg_index++]);
rotating = std::stoull(argv[arg_index++]) * 1024 * 1024;
}
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F4 = ck::f4x2_pk_t;
using F8 = ck::f8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using MFMA = ck::tensor_layout::gemm::MFMA;
auto profile =
[&](auto a_type, auto b_type, auto c_type, auto a_layout, auto b_layout, auto c_layout) {
using ADataType = decltype(a_type);
using BDataType = decltype(b_type);
using CDataType = decltype(c_type);
using ALayout = decltype(a_layout);
using BLayout = decltype(b_layout);
using CLayout = decltype(c_layout);
const int DefaultStrideA = ck::is_same_v<ALayout, Row> ? K : M;
const int DefaultStrideB = ck::is_same_v<BLayout, Row> ? N : K;
const int DefaultStrideC = ck::is_same_v<CLayout, Row> ? N : M;
bool pass = ck::profiler::profile_gemm_mx_impl<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout,
32>( //
do_verification,
init_method,
do_log,
time_kernel,
M,
N,
K,
(StrideA < 0) ? DefaultStrideA : StrideA,
(StrideB < 0) ? DefaultStrideB : StrideB,
(StrideC < 0) ? DefaultStrideC : StrideC,
KBatch,
n_warmup,
n_iter,
rotating);
return pass ? 0 : 1;
};
if(data_type == GemmDataType::F4_F4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F4{}, F4{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F4_F4_F16 && layout == GemmMatrixLayout::MK_MFMA_MN)
{
return profile(F4{}, F4{}, F16{}, Row{}, MFMA{}, Row{});
}
else if(data_type == GemmDataType::F8_F8_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F8{}, F8{}, F16{}, Row{}, Col{}, Row{});
}
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F8{}, F8{}, BF16{}, Row{}, Col{}, Row{});
}
else
{
std::cout << "this data_type & layout is not implemented" << std::endl;
return 1;
}
}
REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_mx);
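
A hypothetical invocation matching the 11-argument form above (sizes are illustrative, not from the commit):

./ckProfiler gemm_mx 1 1 1 2 0 1 3840 4096 4096

i.e. fp8 x fp8 -> f16 (arg2 = 1), A[m, k] * B[n, k] layout (arg3 = 1), verification on, decimal initialization, no tensor printing, kernel timing enabled, M = 3840, N = 4096, K = 4096, with packed default strides.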

View File

@@ -12,7 +12,7 @@ using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
using F6 = ck::f6_t;
using BF6 = ck::bf6_t;
using F4 = ck::f4_t;
using F4 = ck::f4x2_pk_t;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using F32 = float;
@@ -52,22 +52,23 @@ class TestGemmMX_KM_NK
};
// clang-format off
using KernelTypes_F8_MK_NK = ::testing::Types<
using KernelTypes_MK_NK = ::testing::Types<
#if defined(CK_ENABLE_FP8)
// ADataType, BDataType, CDataType, ScaleBlockSize
std::tuple< F8, F8, F16, ck::Number<32> >,
std::tuple< F8, F8, BF16, ck::Number<32> >
std::tuple< F8, F8, BF16, ck::Number<32> >,
#endif
std::tuple< F4, F4, F16, ck::Number<32> >
>;
using KernelTypes_BF8_F8_MK_KN = ::testing::Types<
using KernelTypes_MK_KN = ::testing::Types<
#if defined(CK_ENABLE_FP8)
// ADataType, BDataType, CDataType, ScaleBlockSize
std::tuple< BF8, F8, F16, ck::Number<32> >
#endif
>;
using KernelTypes_F8_KM_NK = ::testing::Types<
using KernelTypes_KM_NK = ::testing::Types<
#if defined(CK_ENABLE_FP8)
// ADataType, BDataType, CDataType, ScaleBlockSize
std::tuple< F8, F8, BF16, ck::Number<32> >
@@ -75,9 +76,9 @@ using KernelTypes_F8_KM_NK = ::testing::Types<
>;
// clang-format on
TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_F8_MK_NK);
TYPED_TEST_SUITE(TestGemmMX_MK_KN, KernelTypes_BF8_F8_MK_KN);
TYPED_TEST_SUITE(TestGemmMX_KM_NK, KernelTypes_F8_KM_NK);
TYPED_TEST_SUITE(TestGemmMX_MK_NK, KernelTypes_MK_NK);
TYPED_TEST_SUITE(TestGemmMX_MK_KN, KernelTypes_MK_KN);
TYPED_TEST_SUITE(TestGemmMX_KM_NK, KernelTypes_KM_NK);
/// A: RowMajor
/// B: ColMajor
@@ -214,7 +215,8 @@ TYPED_TEST(TestGemmMX_MK_KN, Large)
TYPED_TEST(TestGemmMX_KM_NK, SmallN)
{
constexpr int M = 256;
std::vector<int> Ns{1, 2, 3, 4, 5, 6};
std::vector<int> Ns{32, 64};
// std::vector<int> Ns{1, 2, 3, 4, 5, 6};
constexpr int K = 512;
constexpr int StrideA = M;
@@ -222,16 +224,16 @@ TYPED_TEST(TestGemmMX_KM_NK, SmallN)
for(int N : Ns)
{
const auto new_N = N * 8;
const auto StrideC = new_N;
this->Run(M, new_N, K, StrideA, StrideB, StrideC);
const auto StrideC = N;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}
TYPED_TEST(TestGemmMX_KM_NK, MidLargeN)
{
constexpr int M = 256;
std::vector<int> Ns{127, 255, 312, 799, 1573};
std::vector<int> Ns{128, 256, 2048};
// std::vector<int> Ns{127, 255, 312, 799, 1573};
constexpr int K = 512;
constexpr int StrideA = M;
@@ -239,9 +241,8 @@ TYPED_TEST(TestGemmMX_KM_NK, MidLargeN)
for(int N : Ns)
{
const auto new_N = (N + 7) / 8 * 8;
const auto StrideC = new_N;
this->Run(M, new_N, K, StrideA, StrideB, StrideC);
const auto StrideC = N;
this->Run(M, N, K, StrideA, StrideB, StrideC);
}
}

View File

@@ -18,6 +18,7 @@
#include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "profiler/profile_gemm_mx_impl.hpp"
namespace ck {
namespace test {
@@ -27,401 +28,6 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
} // namespace
template <typename ADataType,
typename BDataType,
typename CDataType,
typename ALayout,
typename BLayout,
typename CLayout,
int ScaleBlockSize>
bool profile_gemm_mx_impl(int do_verification,
int init_method,
bool do_log,
bool time_kernel,
int M,
int N,
int K,
int StrideA,
int StrideB,
int StrideC,
int KBatch,
int n_warmup,
int n_iter,
uint64_t rotating = 0)
{
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be a multiple of ScaleBlockSize.");
}
using ScaleDataType = e8m0_bexp_t;
using AScaleLayout = Row;
using BScaleLayout = Col;
bool pass = true;
auto f_host_tensor_descriptor =
[](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
using namespace ck::literals;
if(is_same<decltype(layout), tensor_layout::gemm::RowMajor>::value)
{
return HostTensorDescriptor({row, col}, {stride, 1_uz});
}
else
{
return HostTensorDescriptor({row, col}, {1_uz, stride});
}
};
auto f_get_default_stride =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
if(stride == -1)
{
// if stride is -1, fall back to a packed (contiguous) default stride
if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
{
return static_cast<ck::index_t>(col);
}
else
{
return static_cast<ck::index_t>(row);
}
}
else
return static_cast<ck::index_t>(stride);
};
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
Tensor<ScaleDataType> a_m_k_scale(f_host_tensor_descriptor(
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
Tensor<ScaleDataType> b_k_n_scale(f_host_tensor_descriptor(
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
Tensor<CDataType> c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
std::size_t total_gemm_needed =
a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() +
a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes();
int rotating_count = std::max(
1,
std::min(n_iter,
static_cast<int>(std::ceil(static_cast<double>(rotating) / total_gemm_needed))));
std::cout << "a_m_k: " << a_m_k.mDesc << std::endl;
std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl;
std::cout << "b_k_n: " << b_k_n.mDesc << std::endl;
std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl;
std::cout << "c_m_n: " << c_m_n_device_result.mDesc << std::endl;
std::cout << "rotating count: " << rotating_count << std::endl;
switch(init_method)
{
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
ck::utils::FillConstant<ScaleDataType>{ck::type_convert<ScaleDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(0.5f)}(b_k_n);
ck::utils::FillConstant<ScaleDataType>{ck::type_convert<ScaleDataType>(1.0f)}(b_k_n_scale);
if(do_log)
{
std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A scale = {2.0}" << std::endl;
std::cout << "Init B = {0.5}" << std::endl;
std::cout << "Init B scale = {1.0}" << std::endl;
std::cout << "Expect C = {K}" << std::endl;
}
break;
case 1:
a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-4, 5}); // Z[-4,4]
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-4, 5}); // Z[-4,4]
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_2<ScaleDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_2<ScaleDataType>{125, 129}); // scales: {0.25, 0.5, 1, 2}
break;
default:
a_m_k.GenerateTensorValue(GeneratorTensor_3<ADataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(
GeneratorTensor_3<ScaleDataType>{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1]
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(
GeneratorTensor_3<ScaleDataType>{powf(2.0f, -125.0f), 1.0f});
break;
}
using AElementOp = ck::tensor_operation::element_wise::PassThrough;
using BElementOp = ck::tensor_operation::element_wise::PassThrough;
using CElementOp = ck::tensor_operation::element_wise::PassThrough;
const auto a_element_op = AElementOp{};
const auto b_element_op = BElementOp{};
const auto c_element_op = CElementOp{};
if(do_log > 0)
std::cout << "Device memory allocation..." << std::endl;
DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize());
DeviceMem a_scale_device_buf(sizeof(ScaleDataType) * a_m_k_scale.mDesc.GetElementSpaceSize());
DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize());
DeviceMem b_scale_device_buf(sizeof(ScaleDataType) * b_k_n_scale.mDesc.GetElementSpaceSize());
DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize());
if(do_log > 0)
std::cout << "Upload data to device..." << std::endl;
a_device_buf.ToDevice(a_m_k.mData.data());
a_scale_device_buf.ToDevice(a_m_k_scale.mData.data());
b_device_buf.ToDevice(b_k_n.mData.data());
b_scale_device_buf.ToDevice(b_k_n_scale.mData.data());
if(do_log > 0)
std::cout << "Done." << std::endl;
using DeviceOp = ck::tensor_operation::device::DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
ScaleDataType,
BDataType,
ScaleDataType,
CDataType,
ScaleBlockSize,
AElementOp,
BElementOp,
CElementOp>;
// get device op instances
const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory<
DeviceOp>::GetInstances();
std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
// Run reference GEMM
if(do_verification)
{
using ReferenceGemmInstance =
ck::tensor_operation::host::ReferenceMXGemm<ADataType,
BDataType,
CDataType,
float, // AccDataType
ScaleDataType,
AElementOp,
BElementOp,
CElementOp,
float, // ComputeTypeA
float // ComputeTypeB
>;
auto ref_gemm = ReferenceGemmInstance{};
auto ref_invoker = ref_gemm.MakeInvoker();
auto ref_argument = ref_gemm.MakeArgument(a_m_k,
a_m_k_scale,
b_k_n,
b_k_n_scale,
c_m_n_host_result,
a_element_op,
b_element_op,
c_element_op);
ref_invoker.Run(ref_argument);
}
std::string best_op_name;
std::optional<std::string> best_op_object_name;
float best_ave_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
float best_kbatch = 0;
// profile device GEMM instances
for(auto& op_ptr : op_ptrs)
{
std::vector<int> kbatch_list = {1, 2, 4, 8, 16, 19, 32, 38}; // use these when KBatch <= 0
if(KBatch > 0)
{
kbatch_list = {KBatch};
}
for(std::size_t i = 0; i < kbatch_list.size(); i++)
{
auto kbatch_curr = kbatch_list[i];
auto argument_ptr = op_ptr->MakeArgumentPointer(
static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<ScaleDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<ScaleDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideC,
kbatch_curr,
a_element_op,
b_element_op,
c_element_op);
auto invoker_ptr = op_ptr->MakeInvokerPointer();
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
// re-init C to zero before profiling next kernel
c_device_buf.SetZero();
invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr, false, 0, n_warmup, n_iter});
if(do_verification)
{
c_device_buf.FromDevice(c_m_n_device_result.mData.data());
if(do_log)
{
if(init_method == 0)
{
auto expected = static_cast<float>(K);
auto computed = type_convert<float>(c_m_n_device_result(0, 12));
pass = pass & (std::abs(expected - computed) <= 0.0f);
std::cout << "\nExpected vs Computed: " << expected << " vs "
<< computed << ((pass) ? " (PASSED!)" : " (FAILED!)")
<< std::endl
<< std::endl;
}
else
{
LogRangeAsType<float>(std::cout << "a : ", a_m_k.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "a_scale : ", a_m_k_scale.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b: ", b_k_n.mData, ",")
<< std::endl;
LogRangeAsType<float>(std::cout << "b_scale: ", b_k_n_scale.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_host : ", c_m_n_host_result.mData, ",")
<< std::endl;
LogRangeAsType<float>(
std::cout << "c_device: ", c_m_n_device_result.mData, ",")
<< std::endl;
}
}
pass = pass & ck::utils::check_err(c_m_n_device_result, c_m_n_host_result);
}
std::string op_name = op_ptr->GetTypeString();
std::optional<std::string> op_obj_name = op_ptr->GetObjectName();
float ave_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0,
n_warmup,
n_iter,
rotating_count > 1,
rotating_count});
// Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) +
// scaling of partial sums(K/ScaleBlockSize)]
// FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize
std::size_t flop =
std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize;
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N +
sizeof(ScaleDataType) * (M * K + K * N) / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops
<< " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
<< kbatch_curr << std::endl;
if(tflops > best_tflops && ave_time > 1e-10)
{
best_op_name = op_name;
best_op_object_name = op_obj_name;
best_tflops = tflops;
best_ave_time = ave_time;
best_gb_per_sec = gb_per_sec;
best_kbatch = kbatch_curr;
}
}
else
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
}
}
}
if constexpr(is_same<CDataType, float>::value)
{
std::cout << "Best Perf for datatype = f32";
}
else if constexpr(is_same<CDataType, half_t>::value)
{
std::cout << "Best Perf for datatype = f16";
}
else if constexpr(is_same<CDataType, bhalf_t>::value)
{
std::cout << "Best Perf for datatype = bf16";
}
else if constexpr(is_same<CDataType, int8_t>::value)
{
std::cout << "Best Perf for datatype = int8";
}
if constexpr(is_same<ALayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " ALayout = RowMajor";
}
else if constexpr(is_same<ALayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " ALayout = ColumnMajor";
}
if constexpr(is_same<BLayout, tensor_layout::gemm::RowMajor>::value)
{
std::cout << " BLayout = RowMajor";
}
else if constexpr(is_same<BLayout, tensor_layout::gemm::ColumnMajor>::value)
{
std::cout << " BLayout = ColumnMajor";
}
std::cout << " M = " << M << " N = " << N << " K = " << K << " StrideA = " << StrideA
<< " StrideB = " << StrideB << " StrideC = " << StrideC << " KBatch = " << best_kbatch
<< " : " << best_ave_time << " ms, " << best_tflops << " TFlops, " << best_gb_per_sec
<< " GB/s, " << best_op_name << std::endl;
if(best_op_object_name)
std::cout << best_op_object_name.value() << std::endl;
return pass;
}
template <typename Tuple>
class TestGemmMX : public testing::Test
{
@@ -471,25 +77,25 @@ class TestGemmMX : public testing::Test
int n_warmup = 1,
int n_iter = 10)
{
bool pass = ck::test::profile_gemm_mx_impl<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout,
ScaleBlockSize>(verify_,
init_method_,
log_,
bench_,
M,
N,
K,
StrideA,
StrideB,
StrideC,
kbatch,
n_warmup,
n_iter);
bool pass = ck::profiler::profile_gemm_mx_impl<ADataType,
BDataType,
CDataType,
ALayout,
BLayout,
CLayout,
ScaleBlockSize>(verify_,
init_method_,
log_,
bench_,
M,
N,
K,
StrideA,
StrideB,
StrideC,
kbatch,
n_warmup,
n_iter);
EXPECT_TRUE(pass);
}
};

View File

@@ -74,7 +74,11 @@ struct mfma_scale_type_selector<16, 16>
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>{};
op.template run<16, 16>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc);
op.template run<16, 16, 0, 0>(fragA,
ck::utils::get_exponent_value(scale_a[Number<0>{}]),
fragB,
ck::utils::get_exponent_value(scale_b[Number<0>{}]),
fragAcc);
}
};
@@ -93,7 +97,11 @@ struct mfma_scale_type_selector<32, 32>
AccumFragT& fragAcc)
{
auto op = mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>{};
op.template run<32, 32>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc);
op.template run<32, 32, 0, 0>(fragA,
ck::utils::get_exponent_value(scale_a[Number<0>{}]),
fragB,
ck::utils::get_exponent_value(scale_b[Number<0>{}]),
fragAcc);
}
};
@@ -921,14 +929,12 @@ template <typename AType,
typename ALayout,
typename BLayout,
typename CLayout>
__global__ void matmul(const typename packed_type<AType>::type* a,
const typename packed_type<BType>::type* b,
CType* c)
__global__ void matmul(const packed_type_t<AType>* a, const packed_type_t<BType>* b, CType* c)
{
using PackedAType = typename packed_type<AType>::type;
constexpr auto packed_size_a = packed_type<AType>::packed_size;
using PackedBType = typename packed_type<BType>::type;
constexpr auto packed_size_b = packed_type<BType>::packed_size;
using PackedAType = packed_type_t<AType>;
constexpr auto packed_size_a = packed_size_v<PackedAType>;
using PackedBType = packed_type_t<BType>;
constexpr auto packed_size_b = packed_size_v<PackedBType>;
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
@@ -1005,9 +1011,9 @@ __global__ void matmul(const packed_type_t<AType>* a,
CType* c)
{
using PackedAType = packed_type_t<AType>;
constexpr auto packed_size_a = packed_size_v<AType>;
constexpr auto packed_size_a = packed_size_v<PackedAType>;
using PackedBType = packed_type_t<BType>;
constexpr auto packed_size_b = packed_size_v<BType>;
constexpr auto packed_size_b = packed_size_v<PackedBType>;
constexpr int WAVE_SIZE = 64;
assert(threadIdx.x < WAVE_SIZE);
@@ -1181,10 +1187,10 @@ template <typename DeviceMFMA,
index_t BLOCK_X>
struct TestMXMFMA
{
using PackedAType = typename packed_type<ADataType>::type;
static constexpr auto packed_size_a = packed_type<ADataType>::packed_size;
using PackedBType = typename packed_type<BDataType>::type;
static constexpr auto packed_size_b = packed_type<BDataType>::packed_size;
using PackedAType = packed_type_t<ADataType>;
static constexpr auto packed_size_a = packed_size_v<PackedAType>;
using PackedBType = packed_type_t<BDataType>;
static constexpr auto packed_size_b = packed_size_v<PackedBType>;
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{
@@ -1384,11 +1390,10 @@ template <typename DeviceMFMA,
index_t BLOCK_K>
struct TestMFMA
{
using PackedAType = typename packed_type<ADataType>::type;
static constexpr auto packed_size_a = packed_type<ADataType>::packed_size;
using PackedBType = typename packed_type<BDataType>::type;
static constexpr auto packed_size_b = packed_type<BDataType>::packed_size;
using PackedAType = packed_type_t<ADataType>;
static constexpr auto packed_size_a = packed_size_v<PackedAType>;
using PackedBType = packed_type_t<BDataType>;
static constexpr auto packed_size_b = packed_size_v<PackedBType>;
auto PrepareGemmTensors(const GemmParams& params, index_t init)
{