Introduce MX GEMM for FP8 data type (#2000)

[ROCm/composable_kernel commit: 6660dc6b8e]
This commit is contained in:
Andriy Roshchenko
2025-03-24 15:41:07 -06:00
committed by GitHub
parent fd151c05d9
commit bbdd7f6d57
11 changed files with 4129 additions and 135 deletions

View File

@@ -144,7 +144,7 @@ function(clang_tidy_check TARGET)
# COMMAND ${CLANG_TIDY_COMMAND} $<JOIN:$<TARGET_PROPERTY:${TARGET},SOURCES>, >
foreach(SOURCE ${SOURCES})
if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS"))
string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file)
string(MD5 tidy_file "${SOURCE}")
set(tidy_target tidy-target-${TARGET}-${tidy_file})
add_custom_target(${tidy_target}
# for some targets clang-tidy not able to get information from .clang-tidy

View File

@@ -9,20 +9,17 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/utility/blkgemmpipe_scheduler.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/host_tensor.hpp"
using ScaleDataType = ck::e8m0_bexp_t;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
@@ -31,6 +28,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using ck::type_convert;
struct ExecutionConfig final
{
int do_verification = 1; // (0=no, 1=CPU)
@@ -39,8 +38,9 @@ struct ExecutionConfig final
int verbosity = 0; // (0=no info, 1=verbose info)
};
struct ProblemSize final
struct ProblemSizeSplitK final
{
ck::index_t M = 3840;
ck::index_t N = 4096;
ck::index_t K = 4096;
@@ -48,9 +48,14 @@ struct ProblemSize final
ck::index_t StrideA = -1;
ck::index_t StrideB = -1;
ck::index_t StrideC = -1;
ck::index_t KBatch = 1;
};
bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
bool parse_cmd_args(int argc,
char* argv[],
ProblemSizeSplitK& problem_size,
ExecutionConfig& config)
{
if(argc == 1)
{
@@ -63,7 +68,7 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution
config.time_kernel = std::stoi(argv[3]);
config.verbosity = std::stoi(argv[4]);
}
else if(argc == 11)
else if(argc >= 11)
{
config.do_verification = std::stoi(argv[1]);
config.init_method = std::stoi(argv[2]);
@@ -77,6 +82,11 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution
problem_size.StrideA = std::stoi(argv[8]);
problem_size.StrideB = std::stoi(argv[9]);
problem_size.StrideC = std::stoi(argv[10]);
if(argc >= 12)
{
problem_size.KBatch = std::stoi(argv[11]);
}
}
else
{
@@ -85,7 +95,8 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution
<< std::endl
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
<< "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl;
<< "arg5 to 10: M(256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
<< "arg11: KBatch" << std::endl;
return false;
}
@@ -99,56 +110,70 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename CElementWiseOp,
typename AElementOp,
typename BElementOp,
typename CElementOp,
typename AccDataType,
typename CShuffleDataType,
ck::index_t MXVectorSize>
bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config)
{
using ELayout = CLayout;
using DsLayout = ck::Tuple<>;
using DsDataType = ck::Tuple<>;
using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CDEElementOp = CElementWiseOp;
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
#if 1
// XXX: These parameters should not exist in MX-native GEMM kernel
static constexpr ck::index_t Scale_Block_M = 128;
static constexpr ck::index_t Scale_Block_N = 128;
#endif
static constexpr ck::index_t Scale_Block_K = MXVectorSize;
static constexpr ck::index_t ScaleBlockSize = MXVectorSize;
// XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA
// instructions.
//
// XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized
// scaled type convert functions.
//
// XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to
// ScaleBlockK (aka MXVectorSize).
// Additionally, the following is also expected:
// static_assert(ScaleBlockM % MPerBlock == 0);
// static_assert(ScaleBlockN % NPerBlock == 0);
// In MX-native GEMM kernel these requirements should be relaxed.
//
// XXX: It appears, by default we are using mfma_f32_16x16x4xf32
// MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk =
// MfmaSelector<float, 16, 16, float>::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32
// XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float
// clang-format off
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
// ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB|
// ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | |
// ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | |
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | |
< ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>;
// clang-format on
static constexpr ck::index_t KPerBlock = 64;
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
ALayout, // ALayout
BLayout, // BLayout
CLayout, // CLayout
ADataType, // ADataType
XDataType, // AScaleDataType
BDataType, // BDataType
XDataType, // BScaleDataType
CDataType, // CDataType
AccDataType, // GemmAccDataType
CShuffleDataType, // CShuffleDataType
AElementOp, // AElementwiseOperation
BElementOp, // BElementwiseOperation
CElementOp, // CElementwiseOperation
GemmSpec, // GemmSpec
MXVectorSize, // ScaleBlockSize: Scaling block size
256, // BlockSize: Thread block size
128, // MPerBlock
128, // NPerBlock
KPerBlock, // KPerBlock
16, // AK1
16, // BK1
32, // MPerXDL
32, // NPerXDL
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
16, // ABlockTransferSrcScalarPerVector
16, // ABlockTransferDstScalarPerVector_AK1
false, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
16, // BBlockTransferSrcScalarPerVector
16, // BBlockTransferDstScalarPerVector_BK1
false, // BBlockLdsExtraN
1, // CShuffleMXdlPerWavePerShuffle
1, // CShuffleNXdlPerWavePerShuffle
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
BlkGemmPSched, // BlkGemmPipeSched
BlkGemmPVer, // BlkGemmPipelineVer
ADataType, // ComputeTypeA
BDataType // ComputeTypeB
>;
auto M = problem_size.M;
auto N = problem_size.N;
@@ -156,6 +181,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
auto StrideA = problem_size.StrideA;
auto StrideB = problem_size.StrideB;
auto StrideC = problem_size.StrideC;
auto KBatch = problem_size.KBatch;
auto f_host_tensor_descriptor =
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
@@ -191,21 +217,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
if(K % Scale_Block_K != 0)
if(K % ScaleBlockSize != 0)
{
throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)");
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
};
auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{});
auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{});
// Hardcode scale layouts as per pipeline assumptions
// TODO: Change default scale layouts to Col for A and Row for B
// TODO: Allow user to specify scale layouts
using AScaleLayout = Row;
using BScaleLayout = Col;
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
Tensor<XDataType> a_m_k_scale(
f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A
Tensor<XDataType> b_k_n_scale(
f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{}));
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{}));
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
Tensor<XDataType> b_k_n_scale(f_host_tensor_descriptor(
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
Tensor<CDataType> c_m_n_host_result(
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification
@@ -223,28 +255,37 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
switch(config.init_method)
{
case 0:
if(config.verbosity > 0)
{
std::cout << "NOTE: No input data initialization." << std::endl;
}
break;
case 1:
case 2:
case 0: // Initializations for development and debugging
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.0f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(b_k_n_scale);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(0.5f)}(b_k_n);
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
if(config.verbosity > 0)
{
std::cout << "Init A = {1}" << std::endl;
std::cout << "Init A scale = {0.5}" << std::endl;
std::cout << "Init B = {1}" << std::endl;
std::cout << "Init B scale = {2.0}" << std::endl;
std::cout << "Init A scale = {2.0}" << std::endl;
std::cout << "Init B = {0.5}" << std::endl;
std::cout << "Init B scale = {1.0}" << std::endl;
std::cout << "Expect C = {K}" << std::endl;
}
break;
case 1:
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.0f, 4.0f}(a_m_k);
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-4.0f, 5.0f}(b_k_n);
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
break;
case 2:
a_m_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0f, 1.0f});
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0f, 1.0f});
break;
default:
if(config.verbosity > 0)
{
@@ -269,31 +310,31 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
if(config.verbosity > 0)
std::cout << "Done." << std::endl;
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto cde_element_op = CDEElementOp{};
auto a_element_op = AElementOp{};
auto b_element_op = BElementOp{};
auto c_element_op = CElementOp{};
constexpr ck::index_t NumDTensor = DsDataType::Size();
// do GEMM
// run GEMM
auto device_op = DeviceOpInstance{};
auto invoker = device_op.MakeInvoker();
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
b_device_buf.GetDeviceBuffer(),
std::array<const void*, NumDTensor>{},
c_device_buf.GetDeviceBuffer(),
M,
N,
K,
StrideA,
StrideB,
std::array<ck::index_t, NumDTensor>{},
StrideC,
a_scale_device_buf.GetDeviceBuffer(),
b_scale_device_buf.GetDeviceBuffer(),
a_element_op,
b_element_op,
cde_element_op);
auto argument =
device_op.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
static_cast<XDataType*>(a_scale_device_buf.GetDeviceBuffer()),
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
static_cast<XDataType*>(b_scale_device_buf.GetDeviceBuffer()),
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
M,
N,
K,
StrideA,
Scale_Stride_AM,
StrideB,
Scale_Stride_BN,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
if(!device_op.IsSupportedArgument(argument))
{
@@ -303,7 +344,10 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
}
if(config.verbosity > 0)
std::cout << "Computing GEMM on device..." << std::endl;
{
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
}
float ave_time =
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
@@ -321,7 +365,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
BDataType,
CDataType,
AccDataType,
float,
XDataType,
PassThrough,
PassThrough,
PassThrough,
@@ -347,12 +391,15 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
std::cout << "Comparing results..." << std::endl;
}
if(config.init_method == 1)
if(config.init_method == 0)
{
res_verified =
res_verified && std::abs(static_cast<float>(K) - c_m_n_device_result(0, 0)) <= 0.0f;
std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0)
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl;
auto expected = static_cast<float>(K);
auto computed = type_convert<float>(c_m_n_device_result(1, 12));
res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
<< std::endl;
}
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
@@ -360,7 +407,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
"Error: Incorrect results!");
if(config.verbosity > 0 && res_verified)
std::cout << "Done." << std::endl;
std::cout << "Verification Successful!" << std::endl;
}
else
{
@@ -370,17 +417,18 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
if(config.time_kernel)
{
std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale
std::size_t flop = std::size_t(2) * M * N * K +
std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
sizeof(CDataType) * M * N +
sizeof(XDataType) * (M * K + K * N) / Scale_Block_K;
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_btype / 1.E6 / ave_time;
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
<< " GB/s" << std::endl;
<< " GB/s, " << device_op.GetTypeString() << std::endl;
}
return res_verified;
@@ -393,13 +441,15 @@ template <typename ADataType,
typename ALayout,
typename BLayout,
typename CLayout,
typename CElementWiseOp,
typename AElementOp,
typename BElementOp,
typename CElementOp,
typename AccDataType,
typename CShuffleDataType,
ck::index_t MXVectorSize>
bool run_mx_gemm_example(int argc, char* argv[])
{
ProblemSize problem_size;
ProblemSizeSplitK problem_size;
ExecutionConfig config;
return parse_cmd_args(argc, argv, problem_size, config) &&
@@ -410,7 +460,9 @@ bool run_mx_gemm_example(int argc, char* argv[])
ALayout,
BLayout,
CLayout,
CElementWiseOp,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,
MXVectorSize>(problem_size, config);

View File

@@ -5,23 +5,24 @@
using ADataType = ck::f8_t;
using BDataType = ck::f8_t;
#if 1
// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type
using XDataType = float;
#else
using XDataType = ck::e8m0_bexp_t;
#endif
// TODO: Enable e8m0_bexp_t and FP8 scale types
using XDataType = ck::half_t;
// using XDataType = ck::e8m0_bexp_t;
using CDataType = ck::half_t;
using AccDataType = float;
using CShuffleDataType = float;
using CDataType = float;
using CShuffleDataType = CDataType;
using ALayout = Row;
using BLayout = Col;
using CLayout = Row;
using AElementOp = PassThrough; // elementwise transformation for A matrix
using BElementOp = PassThrough; // elementwise transformation for B matrix
using CElementOp = PassThrough; // elementwise transformation for C matrix
constexpr ck::index_t mx_vector_size = 128; // scaling block size
constexpr ck::index_t mx_vector_size = 32; // scaling block size
int main(int argc, char* argv[])
{
@@ -32,6 +33,8 @@ int main(int argc, char* argv[])
ALayout,
BLayout,
CLayout,
AElementOp,
BElementOp,
CElementOp,
AccDataType,
CShuffleDataType,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -181,6 +181,23 @@ struct BlockwiseGemmXdlops_pipeline_base
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmXdlops_pipeline_base.
*
* This constructor initializes the thread copy objects for matrices A and B.
* It also performs several compile-time checks to ensure the correctness of the
* matrix tile descriptors.
*
* @param a_origin The origin data index for matrix A.
* @param b_origin The origin data index for matrix B.
*
* @note The constructor includes static assertions to ensure that:
* - The matrix tile descriptors for A and B are known at compile-time.
* - The number of threads in the thread block matches the product of MWaves, NWaves, and
* WaveSize.
* - The dimensions of the block are divisible by the product of the corresponding XDL and
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())

View File

@@ -0,0 +1,69 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp"
namespace ck {
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
BlockGemmPipelineScheduler BlkGemmPipeSche,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
// must be used for compute
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat,
index_t NRepeat,
index_t KPack>
constexpr auto BlockGemmMXPipeline_Selector()
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
return BlockwiseGemmXdlops_pipeline_v1_mx<BlkGemmPipeSche,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>{};
}
else
{
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
}
}
} // namespace ck

View File

@@ -0,0 +1,617 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
namespace ck {
// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_v1_mx
{
};
template <index_t ThreadBlockSize,
index_t ScaleBlockSize,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename ComputeDataType,
typename AccDataType,
typename ATileDesc,
typename BTileDesc,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerXDL,
index_t NPerXDL,
index_t MRepeat, // MXdlPerWave
index_t NRepeat, // NXdlPerWave
index_t KPack>
struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
ThreadBlockSize,
ScaleBlockSize,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
: BlockwiseGemmXdlops_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>
{
using Base = BlockwiseGemmXdlops_pipeline_base<ThreadBlockSize,
ADataType,
BDataType,
ComputeDataType,
AccDataType,
ATileDesc,
BTileDesc,
AMmaTileDesc,
BMmaTileDesc,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
MPerBlock,
NPerBlock,
KPerBlock,
MPerXDL,
NPerXDL,
MRepeat,
NRepeat,
KPack>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::MWaves;
using Base::NWaves;
using Base::WaveSize;
using Base::xdlops_gemm;
using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetWaveIdx;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using Base::AMmaKStride;
using Base::BMmaKStride;
using Tuple4 = typename Base::Tuple4;
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr auto ScalesPerKBlockSize =
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block size
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Full;
}
__device__ static auto CalculateAThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_m = wave_idx[I0];
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
return make_tuple(0, waveId_m, xdlops_a_idx[I1], xdlops_gemm.KPerXdlops * xdlops_a_idx[I0]);
}
__device__ static auto CalculateBThreadOriginDataIndex()
{
const auto wave_idx = GetWaveIdx();
const auto waveId_n = wave_idx[I1];
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
return make_tuple(0, waveId_n, xdlops_b_idx[I1], xdlops_gemm.KPerXdlops * xdlops_b_idx[I0]);
}
/**
* @brief Constructor for BlockwiseGemmXdlops_pipeline_v1_mx.
*
* The primary purpose of this constructor is to modify default initialization of the base class
* with the origin data index suitable for microscaling.
*
* @param a_origin The origin data index for matrix A.
* @param b_origin The origin data index for matrix B.
*
*/
__host__ __device__
BlockwiseGemmXdlops_pipeline_v1_mx(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
: Base(a_origin, b_origin)
{
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleGridBuffer,
typename AScaleGridDesc,
typename AScaleThreadTransfer,
typename BScaleGridBuffer,
typename BScaleGridDesc,
typename BScaleThreadTransfer>
__device__ void Run(
// ABlockCopy
const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
// BBlockCopy
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
// CThread
CThreadBuffer& c_thread_buf,
// A and B scales
const AScaleGridDesc& a_scale_grid_desc,
AScaleThreadTransfer& a_scale_thread_copy,
const AScaleGridBuffer& a_scale_grid_buf,
const BScaleGridDesc& b_scale_grid_desc,
BScaleThreadTransfer& b_scale_thread_copy,
const BScaleGridBuffer& b_scale_grid_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
b_scale_thread_desc.GetElementSpaceSize());
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_assert(xdlops_gemm.mfma_instr.num_groups_per_blk *
xdlops_gemm.mfma_instr.group_size ==
xdlops_gemm.GetRegSizePerXdlops(),
"Assume num_regs_per_blk == num_groups_per_blk * group_size");
// Prefetch a_scales
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
auto a_scale_thread_buf_group =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_group.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_group,
make_tuple(I0, I0),
a_scale_thread_buf_group);
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i));
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_group[Number<i>{}];
});
// go to the next group
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0));
}); // g
// restore row id and advance to the next scale
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size *
xdlops_gemm.mfma_instr.num_groups_per_blk,
1));
}); // k0
// restore column id and advance to the next set of rows
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
}); // m0
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
// Prefetch b_scales
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, 0));
});
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Initialize C
c_thread_buf.Clear();
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
// main body
if constexpr(HasMainLoop)
{
// loop over k with the step KPerBlock
index_t i = 0;
do
{
// -------------------------------------------------------------------------------------------
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
c_thread_buf_per_scale.Clear();
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
// MFMA accumulation
// m = 1:MPerXDL
// n = 1:NPerXDL
// k = 1:KPack
// c(m,n) += a(m,k)*b(k,n)
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
// one scale per k0
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0));
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}(
[&](auto g) {
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}(
[&](auto r) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(
make_tuple(m0, k0, g, r));
constexpr auto reg_offset =
g * xdlops_gemm.mfma_instr.group_size + r;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, reg_offset));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<reg_offset>{}] *
type_convert<AccDataType>(
b_scale_thread_buf[Number<b_scale_offset>{}]) *
type_convert<AccDataType>(
a_scale_thread_buf[Number<a_scale_offset>{}]);
});
});
});
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
auto a_scale_thread_buf_group =
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
a_scale_thread_desc_group.GetElementSpaceSize());
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc_group,
make_tuple(I0, I0),
a_scale_thread_buf_group);
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r));
a_scale_thread_buf(Number<a_scale_offset>{}) =
a_scale_thread_buf_group[Number<r>{}];
});
// go to the next group
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0));
}); // g
// restore row id and advance to the next scale
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size *
xdlops_gemm.mfma_instr.num_groups_per_blk,
1));
}); // k0
// restore column id and advance to the next set of rows
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc,
make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
}); // m0
// restore row id and advance to the next set of scales
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize));
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
make_multi_index(NWaves * NPerXDL, 0));
});
// NWaves * NPerXDL * NRepeat == NPerBlock
b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<a_k_step>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<b_k_step>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
c_thread_buf_per_scale.Clear();
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
// one scale per k0
constexpr index_t b_scale_offset =
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0));
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) {
constexpr index_t a_scale_offset =
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r));
constexpr auto reg_offset =
g * xdlops_gemm.mfma_instr.group_size + r;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale[Number<reg_offset>{}] *
type_convert<AccDataType>(
b_scale_thread_buf[Number<b_scale_offset>{}]) *
type_convert<AccDataType>(
a_scale_thread_buf[Number<a_scale_offset>{}]);
});
});
});
});
});
}
}
// TODO: make this field protected when a_scale_thread_copy_ is moved here
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<MRepeat>{},
Number<KRepeat>{},
Number<xdlops_gemm.mfma_instr.num_groups_per_blk>{},
Number<xdlops_gemm.mfma_instr.group_size>{}));
// Is used to copy data from a_scale_grid to a_scale_thread
static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed(
make_tuple(Number<xdlops_gemm.mfma_instr.group_size>{}, Number<1>{}));
// TODO: make this field protected when b_scale_thread_copy_ is moved here
static constexpr auto b_scale_thread_desc =
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, Number<KRepeat>{}));
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;
};
} // namespace ck

View File

@@ -0,0 +1,50 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
index_t ScaleBlockSize,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation>
struct DeviceGemmMX : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideAScale,
ck::index_t StrideB,
ck::index_t StrideBScale,
ck::index_t StrideC,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,877 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
/**
* \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types
*
* This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for
* microscale-compliant data types.
*
* Assumptions:
* - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification
* - Each scale applies to ScaleBlockSize elements in K direction
* - A scale matrix is row-major
* - B scale matrix is column-major
* - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the
* exponent will be interpreted as conventional biased Float32 exponent (E8M0)
*
* Tunable parameters.
* The CK instance includes a series of tunable template parameters to control the parallel
* granularity of the workload to achieve load balancing on different hardware platforms. These
* parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc.
* - Block Size determines the number of threads in the thread block.
* - M/N/K Per Block determines the size of tile that each thread block is responsible for
* calculating.
* - M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA)
* instructions operating on a per-wavefront basis.
* - A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To
* achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B
* loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will
* result in compilation errors.
*
* Conditions for achieving computational load balancing on different hardware platforms can vary.
*
* Serialized version of the algorithm:
* \code
* // E = A * B + C
* // Loop over E[MPerBlock,NPerBlock] tiles
* for(int mb = 0; mb < M; mb += MPerBlock){
* for(int nb = 0; nb < N; nb += NPerBlock){
* // initialize E[MPerBlock,NPerBlock] tile
* for(int mt = mb; mt < mb + MPerBlock; mt++){
* for(int nt = nb; nt < nb + NPerBlock; nt++){
* E[mt,nt] = C[mt,nt];
* }
* }
*
* // multiply-accumulate per tile
* for(int kb = 0; kb < K; kb += KPerBlock){
* for(int m0 = mb; m0 < mb + MPerBlock; m0 += MWaves * MPerXDL){
* for(int n0 = nb; n0 < nb + NPerBlock; n0 += NWaves * NPerXDL){
* for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){
* for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){
* for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){
* // MFMA accumulation for multirate instructions
* for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPack){
* for(int k_mfma = k_pack; k_mfma < k_pack + KPack; k_mfma += mfma.k_per_blk){
* // MFMA instruction
* for(int m = mw; m < mw + MPerXDL; m++){
* for(int n = nw; n < nw + NPerXDL; n++){
* for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){
* E[m,n] += A[m,k] * B[k,n];
* }
* }
* }
* }
* }
* }
* }
* }
* }
* }
* }
* }
* }
* \endcode
*
*/
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t ScaleBlockSize, // Scaling block size
index_t BlockSize, // Thread block size
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA =
ADataType, // XXX: These should always be the same as ADataType and BDataType
typename ComputeTypeB =
BDataType // TODO: Hardcode them and remove from the list of template parameters
>
struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
CDataType,
ScaleBlockSize,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3<
ALayout,
BLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
// TODO: Check if this is the right algorithm for minimum_occupancy
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
? 2
: 1
: 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
// Tail number could be One to Seven
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::One>;
Run(kernel);
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Full>;
Run(kernel);
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Two>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Three)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Three>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Four)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Four>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Five)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Five>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Six>;
Run(kernel);
}
}
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Seven)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Seven>;
Run(kernel);
}
}
}
}
// Tail number could be Odd or Even
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(arg.KBatch > 1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
static_assert((is_same_v<ADataType, f8_t> || is_same_v<ADataType, bf8_t> ||
is_same_v<ADataType, f6_t> || is_same_v<ADataType, bf6_t> ||
is_same_v<ADataType, f4_t>)&&(is_same_v<BDataType, f8_t> ||
is_same_v<BDataType, bf8_t> ||
is_same_v<BDataType, f6_t> ||
is_same_v<BDataType, bf6_t> ||
is_same_v<BDataType, f4_t>),
"Only microscaling formats are supported for ADataType and BDataType");
static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported");
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const ADataType* p_a,
const AScaleDataType* p_a_scale,
const BDataType* p_b,
const BScaleDataType* p_b_scale,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideScaleA,
index_t StrideB,
index_t StrideScaleB,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_a_scale,
p_b,
p_b_scale,
p_c,
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
void* p_c,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideScaleA,
ck::index_t StrideB,
ck::index_t StrideScaleB,
ck::index_t StrideC,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMX_Xdl_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "Kpack: "
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride << ", "
<< "ScaleBlockSize: "
<< ScaleBlockSize;
// clang-format on
return str.str();
}
REGISTER_EXTRA_PRINTING_METHODS
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

File diff suppressed because it is too large Load Diff

View File

@@ -189,15 +189,36 @@ struct ThreadwiseTensorSliceTransfer_v1r3
const ElementwiseOperation element_op_;
}; // namespace ThreadwiseTensorSliceTransfer_v1r3
// Assume:
// 1. src:
// 1. SrcDesc is not known at compile-time
// 2. SrcBuffer is DynamicBuffer
// 3. src_slice_origin_idx is not known at compile-time
// 2. dst:
// 1. DstDesc is known at compile-time
// 2. DstBuffer is StaticBuffer
// 3. dst_slice_origin_idx is known at compile-time
/**
* @brief Helper structure that facilitates transfer of source (grid) data to destination threads.
*
* @details The following assumptions are made:
* - For Source (Grid) Data:
* 1. The source tensor descriptor SrcDesc is not known at compile-time.
* 2. The source buffer is a dynamic buffer.
* 3. The source slice origin index src_slice_origin_idx is not known at compile-time.
* - For Destination (Thread) Data:
* 1. The destination tensor descriptor DstDesc is known at compile-time.
* 2. The destination buffer dst_buf is a static buffer.
* 3. The destination slice origin index dst_slice_origin_idx is known at compile-time.
*
* @tparam SrcData The data type of the source tensor.
* @tparam DstData The data type of the destination tensor.
* @tparam SrcDesc The descriptor type of the source tensor.
* @tparam DstDesc The descriptor type of the destination tensor.
* @tparam SliceLengths The lengths of the slice to be transferred.
* @tparam DimAccessOrder The order of dimension access for the space-filling curve.
* @tparam SrcVectorDim The dimension along which vectorized access is performed in the source
* tensor.
* @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor.
* @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source
* tensor.
* @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run
* or rolled back one step in MoveSrcSliceWindow
* @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for
* floating-point types).
*
*/
template <typename SrcData,
typename DstData,
typename SrcDesc,

View File

@@ -793,7 +793,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
@@ -817,7 +817,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
@@ -841,7 +841,7 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 32; // from the instruction
static constexpr index_t n_per_blk = 32; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on
@@ -870,7 +870,7 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
static constexpr index_t m_per_blk = 16; // from the instruction
static constexpr index_t n_per_blk = 16; // from the instruction
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
static constexpr bool is_k_reduction = true; // ???
// clang-format on