mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Introduce MX GEMM for FP8 data type (#2000)
[ROCm/composable_kernel commit: 6660dc6b8e]
This commit is contained in:
committed by
GitHub
parent
fd151c05d9
commit
bbdd7f6d57
@@ -144,7 +144,7 @@ function(clang_tidy_check TARGET)
|
||||
# COMMAND ${CLANG_TIDY_COMMAND} $<JOIN:$<TARGET_PROPERTY:${TARGET},SOURCES>, >
|
||||
foreach(SOURCE ${SOURCES})
|
||||
if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS"))
|
||||
string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file)
|
||||
string(MD5 tidy_file "${SOURCE}")
|
||||
set(tidy_target tidy-target-${TARGET}-${tidy_file})
|
||||
add_custom_target(${tidy_target}
|
||||
# for some targets clang-tidy not able to get information from .clang-tidy
|
||||
|
||||
@@ -9,20 +9,17 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/utility/blkgemmpipe_scheduler.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
#include "ck/utility/sequence.hpp"
|
||||
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp"
|
||||
|
||||
#include "ck/library/utility/check_err.hpp"
|
||||
#include "ck/library/utility/device_memory.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
|
||||
using ScaleDataType = ck::e8m0_bexp_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
@@ -31,6 +28,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using ck::type_convert;
|
||||
|
||||
struct ExecutionConfig final
|
||||
{
|
||||
int do_verification = 1; // (0=no, 1=CPU)
|
||||
@@ -39,8 +38,9 @@ struct ExecutionConfig final
|
||||
int verbosity = 0; // (0=no info, 1=verbose info)
|
||||
};
|
||||
|
||||
struct ProblemSize final
|
||||
struct ProblemSizeSplitK final
|
||||
{
|
||||
|
||||
ck::index_t M = 3840;
|
||||
ck::index_t N = 4096;
|
||||
ck::index_t K = 4096;
|
||||
@@ -48,9 +48,14 @@ struct ProblemSize final
|
||||
ck::index_t StrideA = -1;
|
||||
ck::index_t StrideB = -1;
|
||||
ck::index_t StrideC = -1;
|
||||
|
||||
ck::index_t KBatch = 1;
|
||||
};
|
||||
|
||||
bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config)
|
||||
bool parse_cmd_args(int argc,
|
||||
char* argv[],
|
||||
ProblemSizeSplitK& problem_size,
|
||||
ExecutionConfig& config)
|
||||
{
|
||||
if(argc == 1)
|
||||
{
|
||||
@@ -63,7 +68,7 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution
|
||||
config.time_kernel = std::stoi(argv[3]);
|
||||
config.verbosity = std::stoi(argv[4]);
|
||||
}
|
||||
else if(argc == 11)
|
||||
else if(argc >= 11)
|
||||
{
|
||||
config.do_verification = std::stoi(argv[1]);
|
||||
config.init_method = std::stoi(argv[2]);
|
||||
@@ -77,6 +82,11 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution
|
||||
problem_size.StrideA = std::stoi(argv[8]);
|
||||
problem_size.StrideB = std::stoi(argv[9]);
|
||||
problem_size.StrideC = std::stoi(argv[10]);
|
||||
|
||||
if(argc >= 12)
|
||||
{
|
||||
problem_size.KBatch = std::stoi(argv[11]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -85,7 +95,8 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution
|
||||
<< std::endl
|
||||
<< "arg3: time kernel (0=no, 1=yes)" << std::endl
|
||||
<< "arg4: verbosity (0=no info, 1=verbose info)" << std::endl
|
||||
<< "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl;
|
||||
<< "arg5 to 10: M(256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl
|
||||
<< "arg11: KBatch" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -99,56 +110,70 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename CElementWiseOp,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CElementOp,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
ck::index_t MXVectorSize>
|
||||
bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config)
|
||||
{
|
||||
using ELayout = CLayout;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CDEElementOp = CElementWiseOp;
|
||||
|
||||
static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3;
|
||||
static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1;
|
||||
|
||||
#if 1
|
||||
// XXX: These parameters should not exist in MX-native GEMM kernel
|
||||
static constexpr ck::index_t Scale_Block_M = 128;
|
||||
static constexpr ck::index_t Scale_Block_N = 128;
|
||||
#endif
|
||||
static constexpr ck::index_t Scale_Block_K = MXVectorSize;
|
||||
static constexpr ck::index_t ScaleBlockSize = MXVectorSize;
|
||||
|
||||
// XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA
|
||||
// instructions.
|
||||
//
|
||||
// XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized
|
||||
// scaled type convert functions.
|
||||
//
|
||||
// XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to
|
||||
// ScaleBlockK (aka MXVectorSize).
|
||||
// Additionally, the following is also expected:
|
||||
// static_assert(ScaleBlockM % MPerBlock == 0);
|
||||
// static_assert(ScaleBlockN % NPerBlock == 0);
|
||||
// In MX-native GEMM kernel these requirements should be relaxed.
|
||||
//
|
||||
// XXX: It appears, by default we are using mfma_f32_16x16x4xf32
|
||||
// MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl, ComputeTypeB>::selected_mfma.k_per_blk =
|
||||
// MfmaSelector<float, 16, 16, float>::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32
|
||||
// XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float
|
||||
|
||||
// clang-format off
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
// ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB|
|
||||
// ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | |
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | |
|
||||
< ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>;
|
||||
// clang-format on
|
||||
static constexpr ck::index_t KPerBlock = 64;
|
||||
using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3<
|
||||
ALayout, // ALayout
|
||||
BLayout, // BLayout
|
||||
CLayout, // CLayout
|
||||
ADataType, // ADataType
|
||||
XDataType, // AScaleDataType
|
||||
BDataType, // BDataType
|
||||
XDataType, // BScaleDataType
|
||||
CDataType, // CDataType
|
||||
AccDataType, // GemmAccDataType
|
||||
CShuffleDataType, // CShuffleDataType
|
||||
AElementOp, // AElementwiseOperation
|
||||
BElementOp, // BElementwiseOperation
|
||||
CElementOp, // CElementwiseOperation
|
||||
GemmSpec, // GemmSpec
|
||||
MXVectorSize, // ScaleBlockSize: Scaling block size
|
||||
256, // BlockSize: Thread block size
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
KPerBlock, // KPerBlock
|
||||
16, // AK1
|
||||
16, // BK1
|
||||
32, // MPerXDL
|
||||
32, // NPerXDL
|
||||
2, // MXdlPerWave
|
||||
2, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
16, // ABlockTransferSrcScalarPerVector
|
||||
16, // ABlockTransferDstScalarPerVector_AK1
|
||||
false, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
16, // BBlockTransferSrcScalarPerVector
|
||||
16, // BBlockTransferDstScalarPerVector_BK1
|
||||
false, // BBlockLdsExtraN
|
||||
1, // CShuffleMXdlPerWavePerShuffle
|
||||
1, // CShuffleNXdlPerWavePerShuffle
|
||||
S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
8, // CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
BlkGemmPSched, // BlkGemmPipeSched
|
||||
BlkGemmPVer, // BlkGemmPipelineVer
|
||||
ADataType, // ComputeTypeA
|
||||
BDataType // ComputeTypeB
|
||||
>;
|
||||
|
||||
auto M = problem_size.M;
|
||||
auto N = problem_size.N;
|
||||
@@ -156,6 +181,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
auto StrideA = problem_size.StrideA;
|
||||
auto StrideB = problem_size.StrideB;
|
||||
auto StrideC = problem_size.StrideC;
|
||||
auto KBatch = problem_size.KBatch;
|
||||
|
||||
auto f_host_tensor_descriptor =
|
||||
[](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) {
|
||||
@@ -191,21 +217,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
if(K % Scale_Block_K != 0)
|
||||
if(K % ScaleBlockSize != 0)
|
||||
{
|
||||
throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)");
|
||||
throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize.");
|
||||
};
|
||||
|
||||
auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{});
|
||||
auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{});
|
||||
// Hardcode scale layouts as per pipeline assumptions
|
||||
// TODO: Change default scale layouts to Col for A and Row for B
|
||||
// TODO: Allow user to specify scale layouts
|
||||
using AScaleLayout = Row;
|
||||
using BScaleLayout = Col;
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{});
|
||||
auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{});
|
||||
|
||||
Tensor<XDataType> a_m_k_scale(
|
||||
f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A
|
||||
Tensor<XDataType> b_k_n_scale(
|
||||
f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{}));
|
||||
|
||||
Tensor<XDataType> a_m_k_scale(f_host_tensor_descriptor(
|
||||
M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A
|
||||
Tensor<XDataType> b_k_n_scale(f_host_tensor_descriptor(
|
||||
K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B
|
||||
|
||||
Tensor<CDataType> c_m_n_host_result(
|
||||
f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification
|
||||
@@ -223,28 +255,37 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
|
||||
switch(config.init_method)
|
||||
{
|
||||
case 0:
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "NOTE: No input data initialization." << std::endl;
|
||||
}
|
||||
break;
|
||||
case 1:
|
||||
case 2:
|
||||
case 0: // Initializations for development and debugging
|
||||
ck::utils::FillConstant<ADataType>{ck::type_convert<ADataType>(1.0f)}(a_m_k);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(0.5f)}(a_m_k_scale);
|
||||
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(1.0f)}(b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(b_k_n_scale);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(2.0f)}(a_m_k_scale);
|
||||
ck::utils::FillConstant<BDataType>{ck::type_convert<BDataType>(0.5f)}(b_k_n);
|
||||
ck::utils::FillConstant<XDataType>{ck::type_convert<XDataType>(1.0f)}(b_k_n_scale);
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
std::cout << "Init A = {1}" << std::endl;
|
||||
std::cout << "Init A scale = {0.5}" << std::endl;
|
||||
std::cout << "Init B = {1}" << std::endl;
|
||||
std::cout << "Init B scale = {2.0}" << std::endl;
|
||||
std::cout << "Init A scale = {2.0}" << std::endl;
|
||||
std::cout << "Init B = {0.5}" << std::endl;
|
||||
std::cout << "Init B scale = {1.0}" << std::endl;
|
||||
std::cout << "Expect C = {K}" << std::endl;
|
||||
}
|
||||
break;
|
||||
|
||||
case 1:
|
||||
ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.0f, 4.0f}(a_m_k);
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(a_m_k_scale);
|
||||
|
||||
ck::utils::FillUniformDistributionIntegerValue<BDataType>{-4.0f, 5.0f}(b_k_n);
|
||||
ck::utils::FillUniformDistributionIntegerValue<XDataType>{-1.0f, 1.0f}(b_k_n_scale);
|
||||
break;
|
||||
|
||||
case 2:
|
||||
a_m_k.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
a_m_k_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0f, 1.0f});
|
||||
|
||||
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-2.0, 2.0});
|
||||
b_k_n_scale.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0f, 1.0f});
|
||||
break;
|
||||
|
||||
default:
|
||||
if(config.verbosity > 0)
|
||||
{
|
||||
@@ -269,31 +310,31 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Done." << std::endl;
|
||||
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto cde_element_op = CDEElementOp{};
|
||||
auto a_element_op = AElementOp{};
|
||||
auto b_element_op = BElementOp{};
|
||||
auto c_element_op = CElementOp{};
|
||||
|
||||
constexpr ck::index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// do GEMM
|
||||
// run GEMM
|
||||
auto device_op = DeviceOpInstance{};
|
||||
auto invoker = device_op.MakeInvoker();
|
||||
auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(),
|
||||
b_device_buf.GetDeviceBuffer(),
|
||||
std::array<const void*, NumDTensor>{},
|
||||
c_device_buf.GetDeviceBuffer(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
std::array<ck::index_t, NumDTensor>{},
|
||||
StrideC,
|
||||
a_scale_device_buf.GetDeviceBuffer(),
|
||||
b_scale_device_buf.GetDeviceBuffer(),
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
auto argument =
|
||||
device_op.MakeArgument(static_cast<ADataType*>(a_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(a_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<BDataType*>(b_device_buf.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(b_scale_device_buf.GetDeviceBuffer()),
|
||||
static_cast<CDataType*>(c_device_buf.GetDeviceBuffer()),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
Scale_Stride_AM,
|
||||
StrideB,
|
||||
Scale_Stride_BN,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
|
||||
if(!device_op.IsSupportedArgument(argument))
|
||||
{
|
||||
@@ -303,7 +344,10 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
}
|
||||
|
||||
if(config.verbosity > 0)
|
||||
std::cout << "Computing GEMM on device..." << std::endl;
|
||||
{
|
||||
std::cout << "Computing GEMM on device..." << std::endl << std::endl;
|
||||
}
|
||||
|
||||
float ave_time =
|
||||
invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50});
|
||||
|
||||
@@ -321,7 +365,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
BDataType,
|
||||
CDataType,
|
||||
AccDataType,
|
||||
float,
|
||||
XDataType,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
@@ -347,12 +391,15 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
std::cout << "Comparing results..." << std::endl;
|
||||
}
|
||||
|
||||
if(config.init_method == 1)
|
||||
if(config.init_method == 0)
|
||||
{
|
||||
res_verified =
|
||||
res_verified && std::abs(static_cast<float>(K) - c_m_n_device_result(0, 0)) <= 0.0f;
|
||||
std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0)
|
||||
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl;
|
||||
auto expected = static_cast<float>(K);
|
||||
auto computed = type_convert<float>(c_m_n_device_result(1, 12));
|
||||
|
||||
res_verified = res_verified && std::abs(expected - computed) <= 0.0f;
|
||||
std::cout << "\nExpected vs Computed: " << expected << " vs " << computed
|
||||
<< ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
res_verified = res_verified && ck::utils::check_err(c_m_n_device_result,
|
||||
@@ -360,7 +407,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
"Error: Incorrect results!");
|
||||
|
||||
if(config.verbosity > 0 && res_verified)
|
||||
std::cout << "Done." << std::endl;
|
||||
std::cout << "Verification Successful!" << std::endl;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -370,17 +417,18 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
|
||||
|
||||
if(config.time_kernel)
|
||||
{
|
||||
std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale
|
||||
std::size_t flop = std::size_t(2) * M * N * K +
|
||||
std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale
|
||||
std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N +
|
||||
sizeof(CDataType) * M * N +
|
||||
sizeof(XDataType) * (M * K + K * N) / Scale_Block_K;
|
||||
sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize;
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s" << std::endl;
|
||||
<< " GB/s, " << device_op.GetTypeString() << std::endl;
|
||||
}
|
||||
|
||||
return res_verified;
|
||||
@@ -393,13 +441,15 @@ template <typename ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename CElementWiseOp,
|
||||
typename AElementOp,
|
||||
typename BElementOp,
|
||||
typename CElementOp,
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
ck::index_t MXVectorSize>
|
||||
bool run_mx_gemm_example(int argc, char* argv[])
|
||||
{
|
||||
ProblemSize problem_size;
|
||||
ProblemSizeSplitK problem_size;
|
||||
ExecutionConfig config;
|
||||
|
||||
return parse_cmd_args(argc, argv, problem_size, config) &&
|
||||
@@ -410,7 +460,9 @@ bool run_mx_gemm_example(int argc, char* argv[])
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
CElementWiseOp,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
MXVectorSize>(problem_size, config);
|
||||
|
||||
@@ -5,23 +5,24 @@
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
#if 1
|
||||
// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type
|
||||
using XDataType = float;
|
||||
#else
|
||||
using XDataType = ck::e8m0_bexp_t;
|
||||
#endif
|
||||
|
||||
// TODO: Enable e8m0_bexp_t and FP8 scale types
|
||||
using XDataType = ck::half_t;
|
||||
// using XDataType = ck::e8m0_bexp_t;
|
||||
|
||||
using CDataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using CDataType = float;
|
||||
using CShuffleDataType = CDataType;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough; // elementwise transformation for A matrix
|
||||
using BElementOp = PassThrough; // elementwise transformation for B matrix
|
||||
using CElementOp = PassThrough; // elementwise transformation for C matrix
|
||||
|
||||
constexpr ck::index_t mx_vector_size = 128; // scaling block size
|
||||
constexpr ck::index_t mx_vector_size = 32; // scaling block size
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
@@ -32,6 +33,8 @@ int main(int argc, char* argv[])
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
AElementOp,
|
||||
BElementOp,
|
||||
CElementOp,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -181,6 +181,23 @@ struct BlockwiseGemmXdlops_pipeline_base
|
||||
|
||||
using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
|
||||
|
||||
/**
|
||||
* @brief Constructor for BlockwiseGemmXdlops_pipeline_base.
|
||||
*
|
||||
* This constructor initializes the thread copy objects for matrices A and B.
|
||||
* It also performs several compile-time checks to ensure the correctness of the
|
||||
* matrix tile descriptors.
|
||||
*
|
||||
* @param a_origin The origin data index for matrix A.
|
||||
* @param b_origin The origin data index for matrix B.
|
||||
*
|
||||
* @note The constructor includes static assertions to ensure that:
|
||||
* - The matrix tile descriptors for A and B are known at compile-time.
|
||||
* - The number of threads in the thread block matches the product of MWaves, NWaves, and
|
||||
* WaveSize.
|
||||
* - The dimensions of the block are divisible by the product of the corresponding XDL and
|
||||
* repeat dimensions.
|
||||
*/
|
||||
__host__ __device__
|
||||
BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSche,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType, // TODO: remove this as in this pipeline ADataType and BDataType
|
||||
// must be used for compute
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack>
|
||||
constexpr auto BlockGemmMXPipeline_Selector()
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
return BlockwiseGemmXdlops_pipeline_v1_mx<BlkGemmPipeSche,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "BlockGemmPipeline configuration is not available" << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,617 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Naive pipeline with lowest resource request per WGP
|
||||
// GlobalPrefetchStages: 1
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 0
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
|
||||
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_v1_mx
|
||||
{
|
||||
};
|
||||
|
||||
template <index_t ThreadBlockSize,
|
||||
index_t ScaleBlockSize,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename ComputeDataType,
|
||||
typename AccDataType,
|
||||
typename ATileDesc,
|
||||
typename BTileDesc,
|
||||
typename AMmaTileDesc,
|
||||
typename BMmaTileDesc,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MRepeat, // MXdlPerWave
|
||||
index_t NRepeat, // NXdlPerWave
|
||||
index_t KPack>
|
||||
struct BlockwiseGemmXdlops_pipeline_v1_mx<BlockGemmPipelineScheduler::Intrawave,
|
||||
ThreadBlockSize,
|
||||
ScaleBlockSize,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
: BlockwiseGemmXdlops_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>
|
||||
|
||||
{
|
||||
using Base = BlockwiseGemmXdlops_pipeline_base<ThreadBlockSize,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ComputeDataType,
|
||||
AccDataType,
|
||||
ATileDesc,
|
||||
BTileDesc,
|
||||
AMmaTileDesc,
|
||||
BMmaTileDesc,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::MWaves;
|
||||
using Base::NWaves;
|
||||
using Base::WaveSize;
|
||||
using Base::xdlops_gemm;
|
||||
|
||||
using Base::CalculateCThreadOriginDataIndex;
|
||||
using Base::CalculateCThreadOriginDataIndex8D;
|
||||
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetCThreadBuffer;
|
||||
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
|
||||
using Base::GetWaveIdx;
|
||||
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
|
||||
|
||||
using Base::a_block_desc_m0_m1_m2_k;
|
||||
using Base::b_block_desc_n0_n1_n2_k;
|
||||
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
using Tuple4 = typename Base::Tuple4;
|
||||
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static constexpr auto ScalesPerKBlockSize =
|
||||
KPerBlock / ScaleBlockSize; // How many mx-vectors per K block size
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
ignore = num_loop;
|
||||
return TailNumber::Full;
|
||||
}
|
||||
|
||||
__device__ static auto CalculateAThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_m = wave_idx[I0];
|
||||
|
||||
const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_m, xdlops_a_idx[I1], xdlops_gemm.KPerXdlops * xdlops_a_idx[I0]);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
{
|
||||
const auto wave_idx = GetWaveIdx();
|
||||
|
||||
const auto waveId_n = wave_idx[I1];
|
||||
|
||||
const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
|
||||
|
||||
return make_tuple(0, waveId_n, xdlops_b_idx[I1], xdlops_gemm.KPerXdlops * xdlops_b_idx[I0]);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Constructor for BlockwiseGemmXdlops_pipeline_v1_mx.
|
||||
*
|
||||
* The primary purpose of this constructor is to modify default initialization of the base class
|
||||
* with the origin data index suitable for microscaling.
|
||||
*
|
||||
* @param a_origin The origin data index for matrix A.
|
||||
* @param b_origin The origin data index for matrix B.
|
||||
*
|
||||
*/
|
||||
__host__ __device__
|
||||
BlockwiseGemmXdlops_pipeline_v1_mx(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple4 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: Base(a_origin, b_origin)
|
||||
{
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleGridBuffer,
|
||||
typename AScaleGridDesc,
|
||||
typename AScaleThreadTransfer,
|
||||
typename BScaleGridBuffer,
|
||||
typename BScaleGridDesc,
|
||||
typename BScaleThreadTransfer>
|
||||
__device__ void Run(
|
||||
// ABlockCopy
|
||||
const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
// BBlockCopy
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
// CThread
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// A and B scales
|
||||
const AScaleGridDesc& a_scale_grid_desc,
|
||||
AScaleThreadTransfer& a_scale_thread_copy,
|
||||
const AScaleGridBuffer& a_scale_grid_buf,
|
||||
const BScaleGridDesc& b_scale_grid_desc,
|
||||
BScaleThreadTransfer& b_scale_thread_copy,
|
||||
const BScaleGridBuffer& b_scale_grid_buf,
|
||||
index_t num_loop) const
|
||||
{
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_assert(xdlops_gemm.mfma_instr.num_groups_per_blk *
|
||||
xdlops_gemm.mfma_instr.group_size ==
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
"Assume num_regs_per_blk == num_groups_per_blk * group_size");
|
||||
|
||||
// Prefetch a_scales
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
|
||||
auto a_scale_thread_buf_group =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_group.GetElementSpaceSize());
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_group,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_group);
|
||||
|
||||
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i));
|
||||
a_scale_thread_buf(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_group[Number<i>{}];
|
||||
});
|
||||
// go to the next group
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0));
|
||||
}); // g
|
||||
|
||||
// restore row id and advance to the next scale
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size *
|
||||
xdlops_gemm.mfma_instr.num_groups_per_blk,
|
||||
1));
|
||||
}); // k0
|
||||
|
||||
// restore column id and advance to the next set of rows
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
}); // m0
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Prefetch b_scales
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, I0),
|
||||
b_scale_thread_buf);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, 0));
|
||||
});
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto c_thread_buf_per_scale = remove_cvref_t<decltype(c_thread_buf)>();
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
// loop over k with the step KPerBlock
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// -------------------------------------------------------------------------------------------
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
|
||||
constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<b_k_step>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
// MFMA accumulation
|
||||
// m = 1:MPerXDL
|
||||
// n = 1:NPerXDL
|
||||
// k = 1:KPack
|
||||
// c(m,n) += a(m,k)*b(k,n)
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
|
||||
// one scale per k0
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}(
|
||||
[&](auto g) {
|
||||
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}(
|
||||
[&](auto r) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(
|
||||
make_tuple(m0, k0, g, r));
|
||||
|
||||
constexpr auto reg_offset =
|
||||
g * xdlops_gemm.mfma_instr.group_size + r;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, n0, reg_offset));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<reg_offset>{}] *
|
||||
type_convert<AccDataType>(
|
||||
b_scale_thread_buf[Number<b_scale_offset>{}]) *
|
||||
type_convert<AccDataType>(
|
||||
a_scale_thread_buf[Number<a_scale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
|
||||
auto a_scale_thread_buf_group =
|
||||
make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
|
||||
a_scale_thread_desc_group.GetElementSpaceSize());
|
||||
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc_group,
|
||||
make_tuple(I0, I0),
|
||||
a_scale_thread_buf_group);
|
||||
|
||||
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r));
|
||||
a_scale_thread_buf(Number<a_scale_offset>{}) =
|
||||
a_scale_thread_buf_group[Number<r>{}];
|
||||
});
|
||||
// go to the next group
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0));
|
||||
}); // g
|
||||
|
||||
// restore row id and advance to the next scale
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size *
|
||||
xdlops_gemm.mfma_instr.num_groups_per_blk,
|
||||
1));
|
||||
}); // k0
|
||||
|
||||
// restore column id and advance to the next set of rows
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc,
|
||||
make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize));
|
||||
}); // m0
|
||||
|
||||
// restore row id and advance to the next set of scales
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, I0),
|
||||
b_scale_thread_buf);
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
make_multi_index(NWaves * NPerXDL, 0));
|
||||
});
|
||||
// NWaves * NPerXDL * NRepeat == NPerBlock
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(
|
||||
b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize));
|
||||
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
|
||||
constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops;
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<a_k_step>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<b_k_step>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
c_thread_buf_per_scale.Clear();
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
|
||||
// one scale per k0
|
||||
constexpr index_t b_scale_offset =
|
||||
b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) {
|
||||
static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) {
|
||||
constexpr index_t a_scale_offset =
|
||||
a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r));
|
||||
|
||||
constexpr auto reg_offset =
|
||||
g * xdlops_gemm.mfma_instr.group_size + r;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset));
|
||||
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale[Number<reg_offset>{}] *
|
||||
type_convert<AccDataType>(
|
||||
b_scale_thread_buf[Number<b_scale_offset>{}]) *
|
||||
type_convert<AccDataType>(
|
||||
a_scale_thread_buf[Number<a_scale_offset>{}]);
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: make this field protected when a_scale_thread_copy_ is moved here
|
||||
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<MRepeat>{},
|
||||
Number<KRepeat>{},
|
||||
Number<xdlops_gemm.mfma_instr.num_groups_per_blk>{},
|
||||
Number<xdlops_gemm.mfma_instr.group_size>{}));
|
||||
|
||||
// Is used to copy data from a_scale_grid to a_scale_thread
|
||||
static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<xdlops_gemm.mfma_instr.group_size>{}, Number<1>{}));
|
||||
|
||||
// TODO: make this field protected when b_scale_thread_copy_ is moved here
|
||||
static constexpr auto b_scale_thread_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, Number<KRepeat>{}));
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
using Base::b_thread_copy_;
|
||||
using Base::b_thread_desc_;
|
||||
using Base::c_thread_desc_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
50
include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp
Normal file
50
include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename CDataType,
|
||||
index_t ScaleBlockSize,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation>
|
||||
struct DeviceGemmMX : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideAScale,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideBScale,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,877 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
/**
|
||||
* \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types
|
||||
*
|
||||
* This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for
|
||||
* microscale-compliant data types.
|
||||
*
|
||||
* Assumptions:
|
||||
* - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification
|
||||
* - Each scale applies to ScaleBlockSize elements in K direction
|
||||
* - A scale matrix is row-major
|
||||
* - B scale matrix is column-major
|
||||
* - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the
|
||||
* exponent will be interpreted as conventional biased Float32 exponent (E8M0)
|
||||
*
|
||||
* Tunable parameters.
|
||||
* The CK instance includes a series of tunable template parameters to control the parallel
|
||||
* granularity of the workload to achieve load balancing on different hardware platforms. These
|
||||
* parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc.
|
||||
* - Block Size determines the number of threads in the thread block.
|
||||
* - M/N/K Per Block determines the size of tile that each thread block is responsible for
|
||||
* calculating.
|
||||
* - M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA)
|
||||
* instructions operating on a per-wavefront basis.
|
||||
* - A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To
|
||||
* achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B
|
||||
* loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will
|
||||
* result in compilation errors.
|
||||
*
|
||||
* Conditions for achieving computational load balancing on different hardware platforms can vary.
|
||||
*
|
||||
* Serialized version of the algorithm:
|
||||
* \code
|
||||
* // E = A * B + C
|
||||
* // Loop over E[MPerBlock,NPerBlock] tiles
|
||||
* for(int mb = 0; mb < M; mb += MPerBlock){
|
||||
* for(int nb = 0; nb < N; nb += NPerBlock){
|
||||
* // initialize E[MPerBlock,NPerBlock] tile
|
||||
* for(int mt = mb; mt < mb + MPerBlock; mt++){
|
||||
* for(int nt = nb; nt < nb + NPerBlock; nt++){
|
||||
* E[mt,nt] = C[mt,nt];
|
||||
* }
|
||||
* }
|
||||
*
|
||||
* // multiply-accumulate per tile
|
||||
* for(int kb = 0; kb < K; kb += KPerBlock){
|
||||
* for(int m0 = mb; m0 < mb + MPerBlock; m0 += MWaves * MPerXDL){
|
||||
* for(int n0 = nb; n0 < nb + NPerBlock; n0 += NWaves * NPerXDL){
|
||||
* for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){
|
||||
* for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){
|
||||
* for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){
|
||||
* // MFMA accumulation for multirate instructions
|
||||
* for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPack){
|
||||
* for(int k_mfma = k_pack; k_mfma < k_pack + KPack; k_mfma += mfma.k_per_blk){
|
||||
* // MFMA instruction
|
||||
* for(int m = mw; m < mw + MPerXDL; m++){
|
||||
* for(int n = nw; n < nw + NPerXDL; n++){
|
||||
* for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){
|
||||
* E[m,n] += A[m,k] * B[k,n];
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* }
|
||||
* \endcode
|
||||
*
|
||||
*/
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t ScaleBlockSize, // Scaling block size
|
||||
index_t BlockSize, // Thread block size
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA =
|
||||
ADataType, // XXX: These should always be the same as ADataType and BDataType
|
||||
typename ComputeTypeB =
|
||||
BDataType // TODO: Hardcode them and remove from the list of template parameters
|
||||
>
|
||||
struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX<ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3<
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
ScaleBlockSize,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Check if this is the right algorithm for minimum_occupancy
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
|
||||
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
|
||||
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
|
||||
? 2
|
||||
: 1
|
||||
: 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
// Tail number could be One to Seven
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::One>;
|
||||
Run(kernel);
|
||||
}
|
||||
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Full)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
Run(kernel);
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Two>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Three)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Three>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Four)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Four>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Five)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Five>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Six>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
|
||||
TailNumber::Seven)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Seven>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tail number could be Odd or Even
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::AtomicAdd,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
static_assert((is_same_v<ADataType, f8_t> || is_same_v<ADataType, bf8_t> ||
|
||||
is_same_v<ADataType, f6_t> || is_same_v<ADataType, bf6_t> ||
|
||||
is_same_v<ADataType, f4_t>)&&(is_same_v<BDataType, f8_t> ||
|
||||
is_same_v<BDataType, bf8_t> ||
|
||||
is_same_v<BDataType, f6_t> ||
|
||||
is_same_v<BDataType, bf6_t> ||
|
||||
is_same_v<BDataType, f4_t>),
|
||||
"Only microscaling formats are supported for ADataType and BDataType");
|
||||
|
||||
static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported");
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const ADataType* p_a,
|
||||
const AScaleDataType* p_a_scale,
|
||||
const BDataType* p_b,
|
||||
const BScaleDataType* p_b_scale,
|
||||
CDataType* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_a_scale,
|
||||
p_b,
|
||||
p_b_scale,
|
||||
p_c,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
void* p_c,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideScaleA,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideScaleB,
|
||||
ck::index_t StrideC,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmMX_Xdl_CShuffleV3"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
|
||||
<< "Kpack: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::AMmaKStride << ", "
|
||||
<< "ScaleBlockSize: "
|
||||
<< ScaleBlockSize;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
REGISTER_EXTRA_PRINTING_METHODS
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
File diff suppressed because it is too large
Load Diff
@@ -189,15 +189,36 @@ struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
const ElementwiseOperation element_op_;
|
||||
}; // namespace ThreadwiseTensorSliceTransfer_v1r3
|
||||
|
||||
// Assume:
|
||||
// 1. src:
|
||||
// 1. SrcDesc is not known at compile-time
|
||||
// 2. SrcBuffer is DynamicBuffer
|
||||
// 3. src_slice_origin_idx is not known at compile-time
|
||||
// 2. dst:
|
||||
// 1. DstDesc is known at compile-time
|
||||
// 2. DstBuffer is StaticBuffer
|
||||
// 3. dst_slice_origin_idx is known at compile-time
|
||||
/**
|
||||
* @brief Helper structure that facilitates transfer of source (grid) data to destination threads.
|
||||
*
|
||||
* @details The following assumptions are made:
|
||||
* - For Source (Grid) Data:
|
||||
* 1. The source tensor descriptor SrcDesc is not known at compile-time.
|
||||
* 2. The source buffer is a dynamic buffer.
|
||||
* 3. The source slice origin index src_slice_origin_idx is not known at compile-time.
|
||||
* - For Destination (Thread) Data:
|
||||
* 1. The destination tensor descriptor DstDesc is known at compile-time.
|
||||
* 2. The destination buffer dst_buf is a static buffer.
|
||||
* 3. The destination slice origin index dst_slice_origin_idx is known at compile-time.
|
||||
*
|
||||
* @tparam SrcData The data type of the source tensor.
|
||||
* @tparam DstData The data type of the destination tensor.
|
||||
* @tparam SrcDesc The descriptor type of the source tensor.
|
||||
* @tparam DstDesc The descriptor type of the destination tensor.
|
||||
* @tparam SliceLengths The lengths of the slice to be transferred.
|
||||
* @tparam DimAccessOrder The order of dimension access for the space-filling curve.
|
||||
* @tparam SrcVectorDim The dimension along which vectorized access is performed in the source
|
||||
* tensor.
|
||||
* @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor.
|
||||
* @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source
|
||||
* tensor.
|
||||
* @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run
|
||||
* or rolled back one step in MoveSrcSliceWindow
|
||||
* @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for
|
||||
* floating-point types).
|
||||
*
|
||||
*/
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
|
||||
@@ -793,7 +793,7 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x64f8f6f4>
|
||||
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
|
||||
static constexpr index_t m_per_blk = 32; // from the instruction
|
||||
static constexpr index_t n_per_blk = 32; // from the instruction
|
||||
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
|
||||
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
@@ -817,7 +817,7 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x128f8f6f4>
|
||||
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
|
||||
static constexpr index_t m_per_blk = 16; // from the instruction
|
||||
static constexpr index_t n_per_blk = 16; // from the instruction
|
||||
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
|
||||
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
@@ -841,7 +841,7 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_32x32x64f8f6f4>
|
||||
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
|
||||
static constexpr index_t m_per_blk = 32; // from the instruction
|
||||
static constexpr index_t n_per_blk = 32; // from the instruction
|
||||
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
|
||||
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
@@ -870,7 +870,7 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
|
||||
static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
|
||||
static constexpr index_t m_per_blk = 16; // from the instruction
|
||||
static constexpr index_t n_per_blk = 16; // from the instruction
|
||||
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
|
||||
static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
|
||||
static constexpr bool is_k_reduction = true; // ???
|
||||
// clang-format on
|
||||
|
||||
|
||||
Reference in New Issue
Block a user