From bbdd7f6d57484f77d202138a0a42f50a1b4db012 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Mon, 24 Mar 2025 15:41:07 -0600 Subject: [PATCH] Introduce MX GEMM for FP8 data type (#2000) [ROCm/composable_kernel commit: 6660dc6b8eb679ca7654e36fa01cf3b71159bfa5] --- cmake/ClangTidy.cmake | 2 +- .../67_gemm_microscaling/gemm_mx_common.hpp | 274 +- example/67_gemm_microscaling/gemm_mx_fp8.cpp | 21 +- .../blockwise_gemm_pipeline_xdlops_base.hpp | 19 +- ...kwise_gemm_pipeline_xdlops_mx_selector.hpp | 69 + .../blockwise_gemm_pipeline_xdlops_v1_mx.hpp | 617 +++++ .../gpu/device/device_gemm_mx.hpp | 50 + .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 877 +++++++ .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 2288 +++++++++++++++++ .../threadwise_tensor_slice_transfer.hpp | 39 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 8 +- 11 files changed, 4129 insertions(+), 135 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp create mode 100644 include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp diff --git a/cmake/ClangTidy.cmake b/cmake/ClangTidy.cmake index cf77991a64..d0d30d669a 100644 --- a/cmake/ClangTidy.cmake +++ b/cmake/ClangTidy.cmake @@ -144,7 +144,7 @@ function(clang_tidy_check TARGET) # COMMAND ${CLANG_TIDY_COMMAND} $, > foreach(SOURCE ${SOURCES}) if((NOT "${SOURCE}" MATCHES "(h|hpp|hxx)$") AND (NOT "${SOURCE}" MATCHES "TARGET_OBJECTS")) - string(MAKE_C_IDENTIFIER "${SOURCE}" tidy_file) + string(MD5 tidy_file "${SOURCE}") set(tidy_target tidy-target-${TARGET}-${tidy_file}) add_custom_target(${tidy_target} # for some targets clang-tidy not able to get information from .clang-tidy diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 30f03cb53b..b8ff765174 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -9,20 +9,17 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" #include "ck/utility/sequence.hpp" - #include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" - #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" #include "ck/library/utility/host_tensor.hpp" -using ScaleDataType = ck::e8m0_bexp_t; - template using S = ck::Sequence; @@ -31,6 +28,8 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ck::type_convert; + struct ExecutionConfig final { int do_verification = 1; // (0=no, 1=CPU) @@ -39,8 +38,9 @@ struct ExecutionConfig final int verbosity = 0; // (0=no info, 1=verbose info) }; -struct ProblemSize final 
+struct ProblemSizeSplitK final { + ck::index_t M = 3840; ck::index_t N = 4096; ck::index_t K = 4096; @@ -48,9 +48,14 @@ struct ProblemSize final ck::index_t StrideA = -1; ck::index_t StrideB = -1; ck::index_t StrideC = -1; + + ck::index_t KBatch = 1; }; -bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, ExecutionConfig& config) +bool parse_cmd_args(int argc, + char* argv[], + ProblemSizeSplitK& problem_size, + ExecutionConfig& config) { if(argc == 1) { @@ -63,7 +68,7 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution config.time_kernel = std::stoi(argv[3]); config.verbosity = std::stoi(argv[4]); } - else if(argc == 11) + else if(argc >= 11) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); @@ -77,6 +82,11 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution problem_size.StrideA = std::stoi(argv[8]); problem_size.StrideB = std::stoi(argv[9]); problem_size.StrideC = std::stoi(argv[10]); + + if(argc >= 12) + { + problem_size.KBatch = std::stoi(argv[11]); + } } else { @@ -85,7 +95,8 @@ bool parse_cmd_args(int argc, char* argv[], ProblemSize& problem_size, Execution << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M (16x), N(16x), K(16x), StrideA, StrideB, StrideC" << std::endl; + << "arg5 to 10: M(256x), N(128x), K(32x), StrideA, StrideB, StrideC" << std::endl + << "arg11: KBatch" << std::endl; return false; } @@ -99,56 +110,70 @@ template -bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { - using ELayout = CLayout; - using DsLayout = ck::Tuple<>; - using DsDataType = ck::Tuple<>; - using AElementOp = PassThrough; - using BElementOp = PassThrough; - using CDEElementOp = CElementWiseOp; - static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; + static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; -#if 1 - // XXX: These parameters should not exist in MX-native GEMM kernel - static constexpr ck::index_t Scale_Block_M = 128; - static constexpr ck::index_t Scale_Block_N = 128; -#endif - static constexpr ck::index_t Scale_Block_K = MXVectorSize; + static constexpr ck::index_t ScaleBlockSize = MXVectorSize; - // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize MX-specific MFMA - // instructions. - // - // XXX: DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 is not designed to utilize device-optimized - // scaled type convert functions. - // - // XXX: In DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3, KPerBlock is expected to be equal to - // ScaleBlockK (aka MXVectorSize). - // Additionally, the following is also expected: - // static_assert(ScaleBlockM % MPerBlock == 0); - // static_assert(ScaleBlockN % NPerBlock == 0); - // In MX-native GEMM kernel these requirements should be relaxed. 
- // - // XXX: It appears, by default we are using mfma_f32_16x16x4xf32 - // MfmaSelector::selected_mfma.k_per_blk = - // MfmaSelector::selected_mfma.k_per_blk = mfma_f32_16x16x4xf32 - // XXX: GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 assumes scale type is float - - // clang-format off - using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 - // ######| ALayout| BLayout| DsLayout| CLayout| ADataType| AScale| BDataType| BScale| DsDataType| CDataType| GemmAcc| CShuffleDataType|AElementwise|BElementwise| CElementwise| GemmSpec|Block| ScaleBlockM| ScaleBlockN| ScaleBlockK| M| N| K| AK1| BK1| M| N|MXdl|NXdl|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer|ABlockTransfer| ABlock|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer|BBlockTransfer| BBlock| CShuffle| CShuffle|CShuffleBlockTransfer|CDEShuffleBlockTransfer| BlkGemm| BlkGemm|ComputeTypeA|ComputeTypeB|LDSTypeA|LDSTypeB| - // ######| | | | | | DataType| | DataType| | | DataType| | Operation| Operation| Operation| | Size| | | | Per| Per| Per| | | Per| Per| Per| Per| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar|LdsExtraM| ThreadCluster| ThreadCluster|SrcAccessOrder| SrcVector| SrcScalar| DstScalar|LdsExtraN| MXdl| NXdl| ClusterLengths| Scalar| PipeSched| PipelineVer| | | | | - // ######| | | | | | | | | | | | | | | | | | | | |Block|Block| Block| | | XDL| XDL|Wave|Wave| Lengths| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths| ArrangeOrder| | Dim| PerVector| PerVector_BK1| | PerWave| PerWave| MBlock_MPerBlock| PerVectors| | | | | | | - // ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | AK0_M_AK1| | | | | | | BK0_N_BK1| | | | | |PerShuffle|PerShuffle| NBlock_NPerBlock| | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, XDataType, BDataType, XDataType, DsDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 32, 1, 8>, S<8, 8, 1>, BlkGemmPSched, BlkGemmPVer, float, float, float, float>; - // clang-format on + static constexpr ck::index_t KPerBlock = 64; + using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< + ALayout, // ALayout + BLayout, // BLayout + CLayout, // CLayout + ADataType, // ADataType + XDataType, // AScaleDataType + BDataType, // BDataType + XDataType, // BScaleDataType + CDataType, // CDataType + AccDataType, // GemmAccDataType + CShuffleDataType, // CShuffleDataType + AElementOp, // AElementwiseOperation + BElementOp, // BElementwiseOperation + CElementOp, // CElementwiseOperation + GemmSpec, // GemmSpec + MXVectorSize, // ScaleBlockSize: Scaling block size + 256, // BlockSize: Thread block size + 128, // MPerBlock + 128, // NPerBlock + KPerBlock, // KPerBlock + 16, // AK1 + 16, // BK1 + 32, // MPerXDL + 32, // NPerXDL + 2, // MXdlPerWave + 2, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 
0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; auto M = problem_size.M; auto N = problem_size.N; @@ -156,6 +181,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) auto StrideA = problem_size.StrideA; auto StrideB = problem_size.StrideB; auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; auto f_host_tensor_descriptor = [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { @@ -191,21 +217,27 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); - if(K % Scale_Block_K != 0) + if(K % ScaleBlockSize != 0) { - throw std::runtime_error("wrong! K must be multiple of Scale_Block_K (16 or 32)"); + throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); }; - auto Scale_Stride_AM = f_get_default_stride(M, K / Scale_Block_K, StrideA, ALayout{}); - auto Scale_Stride_BN = f_get_default_stride(K / Scale_Block_K, N, StrideB, BLayout{}); + // Hardcode scale layouts as per pipeline assumptions + // TODO: Change default scale layouts to Col for A and Row for B + // TODO: Allow user to specify scale layouts + using AScaleLayout = Row; + using BScaleLayout = Col; - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); + auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); - Tensor a_m_k_scale( - f_host_tensor_descriptor(M, K / Scale_Block_K, Scale_Stride_AM, ALayout{})); // scales for A - Tensor b_k_n_scale( - f_host_tensor_descriptor(K / Scale_Block_K, N, Scale_Stride_BN, BLayout{})); // scales for B + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{})); + + Tensor a_m_k_scale(f_host_tensor_descriptor( + M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A + Tensor b_k_n_scale(f_host_tensor_descriptor( + K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B Tensor c_m_n_host_result( f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification @@ -223,28 +255,37 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) switch(config.init_method) { - case 0: - if(config.verbosity > 0) - { - std::cout << "NOTE: No input data initialization." 
<< std::endl; - } - break; - case 1: - case 2: + case 0: // Initializations for development and debugging ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k); - ck::utils::FillConstant{ck::type_convert(0.5f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n); - ck::utils::FillConstant{ck::type_convert(2.0f)}(b_k_n_scale); + ck::utils::FillConstant{ck::type_convert(2.0f)}(a_m_k_scale); + ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n); + ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); if(config.verbosity > 0) { std::cout << "Init A = {1}" << std::endl; - std::cout << "Init A scale = {0.5}" << std::endl; - std::cout << "Init B = {1}" << std::endl; - std::cout << "Init B scale = {2.0}" << std::endl; + std::cout << "Init A scale = {2.0}" << std::endl; + std::cout << "Init B = {0.5}" << std::endl; + std::cout << "Init B scale = {1.0}" << std::endl; std::cout << "Expect C = {K}" << std::endl; } break; + case 1: + ck::utils::FillUniformDistributionIntegerValue{-5.0f, 4.0f}(a_m_k); + ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(a_m_k_scale); + + ck::utils::FillUniformDistributionIntegerValue{-4.0f, 5.0f}(b_k_n); + ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(b_k_n_scale); + break; + + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{-1.0f, 1.0f}); + + b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{-1.0f, 1.0f}); + break; + default: if(config.verbosity > 0) { @@ -269,31 +310,31 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) if(config.verbosity > 0) std::cout << "Done." << std::endl; - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto cde_element_op = CDEElementOp{}; + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{}; - constexpr ck::index_t NumDTensor = DsDataType::Size(); - - // do GEMM + // run GEMM auto device_op = DeviceOpInstance{}; auto invoker = device_op.MakeInvoker(); - auto argument = device_op.MakeArgument(a_device_buf.GetDeviceBuffer(), - b_device_buf.GetDeviceBuffer(), - std::array{}, - c_device_buf.GetDeviceBuffer(), - M, - N, - K, - StrideA, - StrideB, - std::array{}, - StrideC, - a_scale_device_buf.GetDeviceBuffer(), - b_scale_device_buf.GetDeviceBuffer(), - a_element_op, - b_element_op, - cde_element_op); + auto argument = + device_op.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(a_scale_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + static_cast(b_scale_device_buf.GetDeviceBuffer()), + static_cast(c_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + Scale_Stride_AM, + StrideB, + Scale_Stride_BN, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); if(!device_op.IsSupportedArgument(argument)) { @@ -303,7 +344,10 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) } if(config.verbosity > 0) - std::cout << "Computing GEMM on device..." << std::endl; + { + std::cout << "Computing GEMM on device..." 
<< std::endl << std::endl; + } + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); @@ -321,7 +365,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) BDataType, CDataType, AccDataType, - float, + XDataType, PassThrough, PassThrough, PassThrough, @@ -347,12 +391,15 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) std::cout << "Comparing results..." << std::endl; } - if(config.init_method == 1) + if(config.init_method == 0) { - res_verified = - res_verified && std::abs(static_cast(K) - c_m_n_device_result(0, 0)) <= 0.0f; - std::cout << "Expected vs Computed: " << 1.0f * K << " vs " << c_m_n_device_result(0, 0) - << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl; + auto expected = static_cast(K); + auto computed = type_convert(c_m_n_device_result(1, 12)); + + res_verified = res_verified && std::abs(expected - computed) <= 0.0f; + std::cout << "\nExpected vs Computed: " << expected << " vs " << computed + << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl + << std::endl; } res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, @@ -360,7 +407,7 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) "Error: Incorrect results!"); if(config.verbosity > 0 && res_verified) - std::cout << "Done." << std::endl; + std::cout << "Verification Successful!" << std::endl; } else { @@ -370,17 +417,18 @@ bool run_mx_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) if(config.time_kernel) { - std::size_t flop = std::size_t(2) * M * N * K + M * K + K * N; // GEMM + A scale + B scale + std::size_t flop = std::size_t(2) * M * N * K + + std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N + - sizeof(XDataType) * (M * K + K * N) / Scale_Block_K; + sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; float tflops = static_cast(flop) / 1.E9 / ave_time; float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s" << std::endl; + << " GB/s, " << device_op.GetTypeString() << std::endl; } return res_verified; @@ -393,13 +441,15 @@ template bool run_mx_gemm_example(int argc, char* argv[]) { - ProblemSize problem_size; + ProblemSizeSplitK problem_size; ExecutionConfig config; return parse_cmd_args(argc, argv, problem_size, config) && @@ -410,7 +460,9 @@ bool run_mx_gemm_example(int argc, char* argv[]) ALayout, BLayout, CLayout, - CElementWiseOp, + AElementOp, + BElementOp, + CElementOp, AccDataType, CShuffleDataType, MXVectorSize>(problem_size, config); diff --git a/example/67_gemm_microscaling/gemm_mx_fp8.cpp b/example/67_gemm_microscaling/gemm_mx_fp8.cpp index d2e21698ec..a8ec0e14af 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp8.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp8.cpp @@ -5,23 +5,24 @@ using ADataType = ck::f8_t; using BDataType = ck::f8_t; -#if 1 -// XXX: MX-native GEMM kernel will work with e8m0_bexp_t scale type -using XDataType = float; -#else -using XDataType = ck::e8m0_bexp_t; -#endif + +// TODO: Enable e8m0_bexp_t and FP8 scale types +using XDataType = ck::half_t; +// using XDataType = ck::e8m0_bexp_t; + +using CDataType = ck::half_t; using AccDataType = float; -using CShuffleDataType = float; -using CDataType = float; +using CShuffleDataType = CDataType; using ALayout = 
Row; using BLayout = Col; using CLayout = Row; +using AElementOp = PassThrough; // elementwise transformation for A matrix +using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise transformation for C matrix -constexpr ck::index_t mx_vector_size = 128; // scaling block size +constexpr ck::index_t mx_vector_size = 32; // scaling block size int main(int argc, char* argv[]) { @@ -32,6 +33,8 @@ int main(int argc, char* argv[]) ALayout, BLayout, CLayout, + AElementOp, + BElementOp, CElementOp, AccDataType, CShuffleDataType, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index 45ed6845c2..d7ba2559ea 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -181,6 +181,23 @@ struct BlockwiseGemmXdlops_pipeline_base using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + /** + * @brief Constructor for BlockwiseGemmXdlops_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. + * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding XDL and + * repeat dimensions. + */ __host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), Tuple4 b_origin = CalculateBThreadOriginDataIndex()) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp new file mode 100644 index 0000000000..24f6afc381 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp @@ -0,0 +1,69 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp" + +namespace ck { + +template +constexpr auto BlockGemmMXPipeline_Selector() +{ + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + return BlockwiseGemmXdlops_pipeline_v1_mx{}; + } + else + { + std::cerr << "BlockGemmPipeline configuration is not available" << std::endl; + } +} + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp new file mode 100644 index 0000000000..628dafb063 --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp @@ -0,0 +1,617 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp" + +namespace ck { + +// Naive pipeline with lowest resource request per WGP +// GlobalPrefetchStages: 1 +// LocalPreFillStages: 1 +// LocalPreFetchStages: 0 +// LocalSharedMemoryBuffer: 1 + +template +struct BlockwiseGemmXdlops_pipeline_v1_mx +{ +}; + +template +struct BlockwiseGemmXdlops_pipeline_v1_mx + : BlockwiseGemmXdlops_pipeline_base + +{ + using Base = BlockwiseGemmXdlops_pipeline_base; + using Base::I0; + using Base::I1; + using Base::KRepeat; + using Base::MWaves; + using Base::NWaves; + using Base::WaveSize; + using Base::xdlops_gemm; + + using Base::CalculateCThreadOriginDataIndex; + using Base::CalculateCThreadOriginDataIndex8D; + using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetCThreadBuffer; + using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4; + using Base::GetWaveIdx; + using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2; + using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2; + + using Base::a_block_desc_m0_m1_m2_k; + using Base::b_block_desc_n0_n1_n2_k; + + using Base::AMmaKStride; + using Base::BMmaKStride; + + using Tuple4 = typename Base::Tuple4; + + static constexpr index_t PrefetchStages = 1; + static constexpr index_t PrefillStages = 1; + static constexpr index_t GlobalBufferNum = 1; + + static constexpr auto ScalesPerKBlockSize = + KPerBlock / ScaleBlockSize; // How many mx-vectors per K block size + + __host__ static constexpr bool BlockHasHotloop(index_t num_loop) + { + return num_loop > PrefetchStages; + } + + __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop) + { + ignore = num_loop; + return TailNumber::Full; + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], xdlops_gemm.KPerXdlops * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], xdlops_gemm.KPerXdlops * xdlops_b_idx[I0]); + } + + /** + * @brief Constructor for BlockwiseGemmXdlops_pipeline_v1_mx. + * + * The primary purpose of this constructor is to modify default initialization of the base class + * with the origin data index suitable for microscaling. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. 
+ * + */ + __host__ __device__ + BlockwiseGemmXdlops_pipeline_v1_mx(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : Base(a_origin, b_origin) + { + } + + template + __device__ void Run( + // ABlockCopy + const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + // BBlockCopy + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + // CThread + CThreadBuffer& c_thread_buf, + // A and B scales + const AScaleGridDesc& a_scale_grid_desc, + AScaleThreadTransfer& a_scale_thread_copy, + const AScaleGridBuffer& a_scale_grid_buf, + const BScaleGridDesc& b_scale_grid_desc, + BScaleThreadTransfer& b_scale_thread_copy, + const BScaleGridBuffer& b_scale_grid_buf, + index_t num_loop) const + { + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + static_assert(xdlops_gemm.mfma_instr.num_groups_per_blk * + xdlops_gemm.mfma_instr.group_size == + xdlops_gemm.GetRegSizePerXdlops(), + "Assume num_regs_per_blk == num_groups_per_blk * group_size"); + + // Prefetch a_scales + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { + auto a_scale_thread_buf_group = + make_static_buffer( + a_scale_thread_desc_group.GetElementSpaceSize()); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc_group, + make_tuple(I0, I0), + a_scale_thread_buf_group); + + static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i)); + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_group[Number{}]; + }); + // go to the next group + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); + }); // g + + // restore row id and advance to the next scale + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * + xdlops_gemm.mfma_instr.num_groups_per_blk, + 1)); + }); // k0 + + // restore column id and advance to the next set of rows + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + }); // m0 + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, + make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + + // Prefetch b_scales + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + 
b_scale_thread_buf); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, 0)); + }); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto c_thread_buf_per_scale = remove_cvref_t(); + + // main body + if constexpr(HasMainLoop) + { + // loop over k with the step KPerBlock + index_t i = 0; + do + { + // ------------------------------------------------------------------------------------------- + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + c_thread_buf_per_scale.Clear(); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + // MFMA accumulation + // m = 1:MPerXDL + // n = 1:NPerXDL + // k = 1:KPack + // c(m,n) += a(m,k)*b(k,n) + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + + // one scale per k0 + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); + + static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}( + [&](auto g) { + static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}( + [&](auto r) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset( + make_tuple(m0, k0, g, r)); + + constexpr auto reg_offset = + g * xdlops_gemm.mfma_instr.group_size + r; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset( + make_tuple(m0, n0, reg_offset)); + + c_thread_buf(Number{}) += + c_thread_buf_per_scale[Number{}] * + type_convert( + b_scale_thread_buf[Number{}]) * + type_convert( + a_scale_thread_buf[Number{}]); + }); + }); + }); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { + auto a_scale_thread_buf_group = + make_static_buffer( + a_scale_thread_desc_group.GetElementSpaceSize()); + + a_scale_thread_copy.Run(a_scale_grid_desc, + a_scale_grid_buf, + a_scale_thread_desc_group, + make_tuple(I0, I0), + a_scale_thread_buf_group); + + static_for<0, 
xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_group[Number{}]; + }); + // go to the next group + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); + }); // g + + // restore row id and advance to the next scale + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * + xdlops_gemm.mfma_instr.num_groups_per_blk, + 1)); + }); // k0 + + // restore column id and advance to the next set of rows + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, + make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); + }); // m0 + + // restore row id and advance to the next set of scales + a_scale_thread_copy.MoveSrcSliceWindow( + a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, I0), + b_scale_thread_buf); + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, 0)); + }); + // NWaves * NPerXDL * NRepeat == NPerBlock + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); + + block_sync_lds(); + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + + static_for<0, KRepeat, 1>{}([&](auto k) { + constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + c_thread_buf_per_scale.Clear(); + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack, 1>{}([&](auto ik) { + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}]; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}]; + }); + + using mfma_input_type = + typename vector_type::type; + + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf_per_scale.GetVectorTypeReference(I0)); + + // one scale per k0 + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); + + static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { + static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); + + constexpr auto reg_offset = + g * xdlops_gemm.mfma_instr.group_size + r; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset)); + + c_thread_buf(Number{}) += + 
c_thread_buf_per_scale[Number{}] * + type_convert( + b_scale_thread_buf[Number{}]) * + type_convert( + a_scale_thread_buf[Number{}]); + }); + }); + }); + }); + }); + } + } + + // TODO: make this field protected when a_scale_thread_copy_ is moved here + static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})); + + // Is used to copy data from a_scale_grid to a_scale_thread + static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number<1>{})); + + // TODO: make this field protected when b_scale_thread_copy_ is moved here + static constexpr auto b_scale_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + + protected: + using Base::a_thread_copy_; + using Base::a_thread_desc_; + using Base::b_thread_copy_; + using Base::b_thread_desc_; + using Base::c_thread_desc_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp new file mode 100644 index 0000000000..e89185a35c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/device_gemm_mx.hpp @@ -0,0 +1,50 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/tensor_operation/gpu/device/device_base.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMX : public BaseOperator +{ + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideAScale, + ck::index_t StrideB, + ck::index_t StrideBScale, + ck::index_t StrideC, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp new file mode 100644 index 0000000000..34df9a1d7b --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -0,0 +1,877 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" + +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +/** + * \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types + * + * This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for + * microscale-compliant data types. 
+ * + * Assumptions: + * - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification + * - Each scale applies to ScaleBlockSize elements in K direction + * - A scale matrix is row-major + * - B scale matrix is column-major + * - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the + * exponent will be interpreted as conventional biased Float32 exponent (E8M0) + * + * Tunable parameters. + * The CK instance includes a series of tunable template parameters to control the parallel + * granularity of the workload to achieve load balancing on different hardware platforms. These + * parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc. + * - Block Size determines the number of threads in the thread block. + * - M/N/K Per Block determines the size of tile that each thread block is responsible for + * calculating. + * - M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA) + * instructions operating on a per-wavefront basis. + * - A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To + * achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B + * loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will + * result in compilation errors. + * + * Conditions for achieving computational load balancing on different hardware platforms can vary. + * + * Serialized version of the algorithm: + * \code + * // E = A * B + C + * // Loop over E[MPerBlock,NPerBlock] tiles + * for(int mb = 0; mb < M; mb += MPerBlock){ + * for(int nb = 0; nb < N; nb += NPerBlock){ + * // initialize E[MPerBlock,NPerBlock] tile + * for(int mt = mb; mt < mb + MPerBlock; mt++){ + * for(int nt = nb; nt < nb + NPerBlock; nt++){ + * E[mt,nt] = C[mt,nt]; + * } + * } + * + * // multiply-accumulate per tile + * for(int kb = 0; kb < K; kb += KPerBlock){ + * for(int m0 = mb; m0 < mb + MPerBlock; m0 += MWaves * MPerXDL){ + * for(int n0 = nb; n0 < nb + NPerBlock; n0 += NWaves * NPerXDL){ + * for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){ + * for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){ + * for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){ + * // MFMA accumulation for multirate instructions + * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPack){ + * for(int k_mfma = k_pack; k_mfma < k_pack + KPack; k_mfma += mfma.k_per_blk){ + * // MFMA instruction + * for(int m = mw; m < mw + MPerXDL; m++){ + * for(int n = nw; n < nw + NPerXDL; n++){ + * for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){ + * E[m,n] += A[m,k] * B[k,n]; + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * } + * \endcode + * + */ +template +struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX +{ + // GridwiseGemm + using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + 
ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + using Argument = typename GridwiseGemm::Argument; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = + a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); + auto size_b_buffer = + b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + // TODO: Check if this is the right algorithm for minimum_occupancy + constexpr index_t minimum_occupancy = + BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave + ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && + MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) + ? 
2 + : 1 + : 2; + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + // Tail number could be One to Seven + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Two>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Three) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Three>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Four>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Five>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Six>; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Seven>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Full) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + 
TailNumber::Three) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Four) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Five) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + + if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Seven) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + // Tail number could be Odd or Even + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3_2lds; + Run(kernel); + } + } + } + else + { + if(arg.KBatch > 1) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + static_assert((is_same_v || is_same_v || + is_same_v || is_same_v || + is_same_v)&&(is_same_v || + is_same_v || + is_same_v || + is_same_v || + is_same_v), + "Only microscaling formats are supported for ADataType and BDataType"); + + static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported"); + + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_xdl_supported()) + { + return false; + } + + if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) + { + 
return false; + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const ADataType* p_a, + const AScaleDataType* p_a_scale, + const BDataType* p_b, + const BScaleDataType* p_b_scale, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideScaleA, + index_t StrideB, + index_t StrideScaleB, + index_t StrideC, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_a_scale, + p_b, + p_b_scale, + p_c, + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_a_scale, + const void* p_b, + const void* p_b_scale, + void* p_c, + ck::index_t M, + ck::index_t N, + ck::index_t K, + ck::index_t StrideA, + ck::index_t StrideScaleA, + ck::index_t StrideB, + ck::index_t StrideScaleB, + ck::index_t StrideC, + ck::index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_a_scale), + static_cast(p_b), + static_cast(p_b_scale), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideScaleA, + StrideB, + StrideScaleB, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMX_Xdl_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + 
karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + // Pass two lds pointer is the key to tell compiler that ds_read/write + // operate on different lds chunk at same time without order dependecy + __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + GridwiseGemm::template Run_2Lds( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared_0, + p_shared_1, + karg); + +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} + +template +struct GridwiseGemmMX_xdl_cshuffle_v3 +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = + ((is_same::value || is_same::value) && + lcm_AK1_BK1 <= 4) + ? 
true + : false; + static constexpr index_t KPack = + math::max(lcm_AK1_BK1, + MfmaSelector:: + selected_mfma.k_per_blk); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + 
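// The Calculate*Padded helpers above all round the problem up to whole tiles. With
// hypothetical tile parameters MPerBlock = 128, KPerBlock = 128, AK1 = 16 and a problem
// M = 1000, K = 4096, KBatch = 2:
//   MPadded = integer_least_multiple(1000, 128)           = 1024
//   K_t     = KBatch * KPerBlock                           = 256
//   KPadded = ceil(4096 / 256) * KPerBlock                 = 16 * 128 = 2048   // per batch
//   AK0     = ceil(4096 / 256) * (KPerBlock / AK1)         = 16 * 8   = 128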
make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding 
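// Every descriptor builder branches on GemmSpec the same way: the specialization name
// lists the dimensions that get right-padded up to the block tile. As an illustrative
// helper (hypothetical, not defined in this file):
//   constexpr bool PadsK(GemmSpecialization s) {
//       return s == GemmSpecialization::KPadding  || s == GemmSpecialization::MKPadding ||
//              s == GemmSpecialization::NKPadding || s == GemmSpecialization::MNKPadding;
//   }
//   // PadsM / PadsN follow the same pattern with the M- and N-containing specializations.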
|| + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Weight Tile Permute + constexpr index_t BK01 = KPerBlock / BK1Value; + // const index_t BK00 = BK0 / BK01; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto + MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto + MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == 
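// Wave decomposition used by the MMA tile descriptors above: with hypothetical
// MPerBlock = 128, MPerXdl = 32, MXdlPerWave = 2,
//   MWaves = MPerBlock / (MXdlPerWave * MPerXdl) = 128 / 64 = 2
// i.e. two waves tile the M extent of a block, each owning MXdlPerWave xdlops tiles of
// MPerXdl rows (NWaves is computed the same way on the N side).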
GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + index_t StrideC_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideScaleA{StrideScaleA_}, + StrideB{StrideB_}, + StrideScaleB{StrideScaleB_}, + StrideC{StrideC_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SScaleA:" << StrideScaleA << ", " + << "SB:" << StrideB << ", " + << "SScaleB:" << StrideScaleB << ", " + << "SC:" << StrideC << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideScaleA; + index_t StrideB; + index_t StrideScaleB; + index_t StrideC; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const AScaleDataType* p_a_scale_grid_, + const BDataType* p_b_grid_, + const BScaleDataType* p_b_scale_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideScaleA_, + index_t StrideB_, + index_t StrideScaleB_, + index_t StrideC_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, + N_, + K_, + StrideA_, + StrideScaleA_, + StrideB_, + StrideScaleB_, + StrideC_, + k_batch_}, + p_a_grid{p_a_grid_}, + p_a_scale_grid{p_a_scale_grid_}, + p_b_grid{p_b_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + p_c_grid{p_c_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const 
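// The two predicates above decide how split-K partial results reach C:
//   KBatch == 1                -> plain Set, no reduction needed
//   KBatch  > 1 && !is_reduce  -> IsAtomicAdd(): partials are atomically added into C
//   KBatch  > 1 &&  is_reduce  -> IsReduceAdd(): each batch writes to its own M*N slice
//                                 (c_reduce_offset = k_id * M * N below) for a later
//                                 reduction pass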
AScaleDataType* p_a_scale_grid; + const BDataType* p_b_grid; + const BScaleDataType* p_b_scale_grid; + CDataType* p_c_grid; + + const AElementwiseOperation a_element_op; + const BElementwiseOperation b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg, index_t k_id) + { + if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = k_id * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = k_id * k0_offset / BPackedSize; + } + } + + // Calculate A scale offset + if constexpr(is_same_v) + { + a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize; + } + else if constexpr(is_same_v) + { + a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + b_scale_k_split_offset = k_id * (karg.KRead / ScaleBlockSize) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + b_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize; + } + + if(k_id < (karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = k_id * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t a_scale_k_split_offset; // New member for scale matrix offset + index_t b_scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 
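// MLdsLayer folds several M-rows of A into one 128-byte LDS span (32 banks x 4 bytes)
// when the K slice of a row is narrower than that. With 1-byte f8 data, APackedSize = 1
// and KPerBlock = 64: LdsSize = 128 / 64 / 1 / 1 = 2, so two rows share a span; for
// KPerBlock >= 128 the expression clamps to 1 and the XOR swizzle below acts on single
// rows.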
1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / MPerXdl; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerXdl * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = BlockwiseGemmPipe::WaveSize / NPerXdl; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerXdl * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock() + { + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock; + } + + using BlockwiseGemmPipe = + remove_cvref_t())>; + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize(); + + return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * 
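// LDS is reused: A and B tiles occupy it during the main loop and the same allocation is
// reinterpreted as the C-shuffle staging buffer afterwards, hence the max(). Rough budget
// with hypothetical MPerBlock = NPerBlock = KPerBlock = 128, 1-byte f8 A/B and a float
// CShuffleDataType staging a 32 x 128 slice per shuffle:
//   A tile: 128 * 128 * 1 B = 16 KiB,  B tile: 16 KiB   -> 32 KiB
//   C shuffle: 32 * 128 * 4 B = 16 KiB
//   GetSharedMemoryNumberOfByte() ~ max(32 KiB, 16 KiB) = 32 KiB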
sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) && + (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, + "Invalid tuning param!"); + + static_assert(KPerBlock % ScaleBlockSize == 0, + "KPerBlock should be multiple of ScaleBlockSize"); + + static_assert(KPerBlock / ScaleBlockSize == BlockwiseGemmPipe::KRepeat, + "Single call to xdlops_gemm::Run should process exactly ScaleBlockSize " + "elements in k dimension"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not support yet" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // return block_id to C matrix tile idx (m0, n0) mapping 
+ // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + 
ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + // NXdlPerWave == NRepeat + // MXdlPerWave == MRepeat + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + + // Initial thread mapping for MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MWaves=NWaves=2 + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] + + auto a_thread_offset_m = + MPerXdl * ((get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) / MWaves) + + mfma.selected_mfma.group_size * + ((get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / MPerXdl); + auto a_thread_offset_k = KPerThread * (get_thread_local_1d_id() % MPerXdl) / MPerXdl; + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + + (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; + auto b_thread_offset_k = KPerThread * (get_thread_local_1d_id() % NPerXdl) / NPerXdl; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), // SrcDesc + decltype(BlockwiseGemmPipe::a_scale_thread_desc_group), // DstDesc + Sequence, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 0, // SrcVectorDim + 1, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + 
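// MX addressing: the A scale grid holds one e8m0 exponent per (row, ScaleBlockSize-wide K
// group), so a thread whose A slice starts at element (m, k) reads its scale at
// (m, k / ScaleBlockSize). The copy constructed here therefore starts from
//   scale_row = block_m_id * MPerBlock + a_thread_offset_m;
//   scale_col = a_thread_offset_k / ScaleBlockSize;             // ScaleBlockSize == 32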
true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, + a_thread_offset_k / ScaleBlockSize)); + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, BlockwiseGemmPipe::KRepeat>, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + BlockwiseGemmPipe::KRepeat, // SrcScalarPerVector + 1, + false>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, + b_thread_offset_k / ScaleBlockSize)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + 
blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + 
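// Epilogue structure for every chunk of the space-filling curve (this loop):
//   block_sync_lds();            // LDS no longer read by the previous chunk
//   VGPR -> LDS   via c_thread_copy_vgpr_to_lds (per thread, InMemoryDataOperationEnum::Set)
//   block_sync_lds();            // whole shuffle tile now visible to the block
//   LDS  -> global via c_shuffle_block_copy_lds_to_global (vectorised along NPerBlock,
//                  Set or AtomicAdd depending on CGlobalMemoryDataOperation)
//   advance the global window to the next shuffle tile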
c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // A Scale grid + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( + make_tuple(problem.M, math::integer_divide_ceil(problem.K, ScaleBlockSize)), + make_tuple(problem.StrideScaleA, 1)); + + // B Scale grid transposed + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), + make_tuple(problem.StrideScaleB, 1)); + + Run(p_a_grid, + p_a_scale_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared, + problem, + a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const AScaleGridDesc_AM_AK& a_scale_grid_desc_am_ak, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock) + { + ignore = p_a_scale_grid; + ignore = a_scale_grid_desc_am_ak; + + // TODO: Implement 2 LDS version + static_assert(false, "Not implemented"); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + // B Scale buffer + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + const 
AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_ping = make_dynamic_buffer( + bit_cast(static_cast(p_shared_0) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf_pong = 
make_dynamic_buffer( + bit_cast(bit_cast(p_shared_1) + + a_block_space_size_aligned * sizeof(ADataType) / APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + // B scale + static constexpr auto mfma = + MfmaSelector{}; + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; + static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + + const index_t ScaleSliceSizeN = NXdlPerWave; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockSize - 1) / ScaleBlockSize; + static constexpr auto KBlockScaleSliceSizeK = + (KPerBlock + ScaleBlockSize - 1) / ScaleBlockSize; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + + auto b_thread_offset_n = + get_thread_local_1d_id() % NPerXdl + + (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; + auto b_thread_offset_k = + (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / NPerXdl * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, + b_thread_offset_k / ScaleBlockSize)); + + constexpr auto b_scale_thread_slice_copy_step = + make_tuple(make_multi_index(NWaves * NPerXdl, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + + blockwise_gemm_pipeline.template Run( + a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_desc, + b_scale_thread_copy, + b_scale_grid_buf, + b_scale_thread_slice_copy_step, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! 
+ // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared_0), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto 
c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const ADataType* p_a_grid, + const BDataType* p_b_grid, + const BScaleDataType* p_b_scale_grid, + CDataType* p_c_grid, + void* p_shared_0, + void* p_shared_1, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto 
b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize)), + make_tuple(problem.StrideScaleB, 1)); + + Run_2Lds(p_a_grid, + p_b_grid, + p_b_scale_grid, + p_c_grid, + p_shared_0, + p_shared_1, + problem, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + b_scale_grid_desc_bn_ak, + c_grid_desc_mblock_mperblock_nblock_nperblock); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index a7efea277f..0310fe37a0 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -189,15 +189,36 @@ struct ThreadwiseTensorSliceTransfer_v1r3 const ElementwiseOperation element_op_; }; // namespace ThreadwiseTensorSliceTransfer_v1r3 -// Assume: -// 1. src: -// 1. SrcDesc is not known at compile-time -// 2. SrcBuffer is DynamicBuffer -// 3. src_slice_origin_idx is not known at compile-time -// 2. dst: -// 1. DstDesc is known at compile-time -// 2. DstBuffer is StaticBuffer -// 3. dst_slice_origin_idx is known at compile-time +/** + * @brief Helper structure that facilitates transfer of source (grid) data to destination threads. + * + * @details The following assumptions are made: + * - For Source (Grid) Data: + * 1. The source tensor descriptor SrcDesc is not known at compile-time. + * 2. The source buffer is a dynamic buffer. + * 3. The source slice origin index src_slice_origin_idx is not known at compile-time. + * - For Destination (Thread) Data: + * 1. The destination tensor descriptor DstDesc is known at compile-time. + * 2. The destination buffer dst_buf is a static buffer. + * 3. The destination slice origin index dst_slice_origin_idx is known at compile-time. + * + * @tparam SrcData The data type of the source tensor. + * @tparam DstData The data type of the destination tensor. + * @tparam SrcDesc The descriptor type of the source tensor. + * @tparam DstDesc The descriptor type of the destination tensor. + * @tparam SliceLengths The lengths of the slice to be transferred. + * @tparam DimAccessOrder The order of dimension access for the space-filling curve. + * @tparam SrcVectorDim The dimension along which vectorized access is performed in the source + * tensor. + * @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor. + * @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source + * tensor. + * @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run + * or rolled back one step in MoveSrcSliceWindow + * @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for + * floating-point types). + * + */ template static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? static constexpr index_t m_per_blk = 32; // from the instruction static constexpr index_t n_per_blk = 32; // from the instruction - static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks static constexpr bool is_k_reduction = true; // ??? // clang-format on @@ -817,7 +817,7 @@ struct mfma_type static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? 
static constexpr index_t m_per_blk = 16; // from the instruction static constexpr index_t n_per_blk = 16; // from the instruction - static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks static constexpr bool is_k_reduction = true; // ??? // clang-format on @@ -841,7 +841,7 @@ struct mfma_type static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? static constexpr index_t m_per_blk = 32; // from the instruction static constexpr index_t n_per_blk = 32; // from the instruction - static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks static constexpr bool is_k_reduction = true; // ??? // clang-format on @@ -870,7 +870,7 @@ struct mfma_type static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ??? static constexpr index_t m_per_blk = 16; // from the instruction static constexpr index_t n_per_blk = 16; // from the instruction - static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks + static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks static constexpr bool is_k_reduction = true; // ??? // clang-format on
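/*
 * Illustrative only (not part of the patch): why the comment rewrite in the hunks
 * above ("64 / num_input_blks" and "128 / num_input_blks" -> "KPerXdlops / num_input_blks")
 * describes the same value. For these k-reduction MFMA tile shapes, k_per_blk is
 * KPerXdlops divided by the number of input blocks, and the old hard-coded numerators
 * were simply the per-shape KPerXdlops values. The num_input_blks values below are
 * inferred from the original comments (64 / x == 32 and 128 / x == 32), not read from
 * the instruction tables, so treat them as assumptions.
 */
// 32x32 tile shape in the hunks above: KPerXdlops = 64, k-reduction over 2 input blocks
constexpr int k_per_xdlops_32x32   = 64;
constexpr int num_input_blks_32x32 = 2;
static_assert(k_per_xdlops_32x32 / num_input_blks_32x32 == 32, "k_per_blk for the 32x32 shape");

// 16x16 tile shape in the hunks above: KPerXdlops = 128, k-reduction over 4 input blocks
constexpr int k_per_xdlops_16x16   = 128;
constexpr int num_input_blks_16x16 = 4;
static_assert(k_per_xdlops_16x16 / num_input_blks_16x16 == 32, "k_per_blk for the 16x16 shape");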
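/*
 * Illustrative only (not part of the patch): a minimal host-side sketch of the math the
 * MX-scaled GEMM above computes, assuming ScaleBlockSize = MXVectorSize (e.g. 32),
 * row-major A [M x K], column-major B [K x N], and per-block e8m0 scales laid out to
 * match the naive scale-grid descriptors in the kernel: A-scale [M x ceil(K/SB)] and
 * the transposed B-scale [N x ceil(K/SB)]. FP8 inputs are stood in for by plain floats,
 * packed row strides replace StrideScaleA / StrideScaleB, and the e8m0 decoding follows
 * the usual OCP MX interpretation; only the block-scaling math is shown.
 */
#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

inline float decode_e8m0(std::uint8_t e)
{
    // Exponent-only scale: value = 2^(e - 127); 0xFF encodes NaN.
    if(e == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// C[m, n] = sum_k (a_scale[m][k/SB] * A[m][k]) * (b_scale[n][k/SB] * B[k][n])
void reference_mx_gemm(int M, int N, int K, int SB,
                       const std::vector<float>& A,              // M x K, row-major
                       const std::vector<std::uint8_t>& a_scale, // M x ceil(K/SB), row-major
                       const std::vector<float>& B,              // K x N, column-major
                       const std::vector<std::uint8_t>& b_scale, // N x ceil(K/SB), row-major
                       std::vector<float>& C)                    // M x N, row-major, pre-sized
{
    const int KBlocks = (K + SB - 1) / SB; // same ceil-divide as math::integer_divide_ceil
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(int k = 0; k < K; ++k)
            {
                const int kb   = k / SB;
                const float sa = decode_e8m0(a_scale[m * KBlocks + kb]);
                const float sb = decode_e8m0(b_scale[n * KBlocks + kb]);
                acc += (sa * A[m * K + k]) * (sb * B[n * K + k]);
            }
            C[m * N + n] = acc;
        }
    }
}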