From 5e2bd20672b24e2b0a7d2413f33bf4ff73e3fb62 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:17:07 -0600 Subject: [PATCH] MX GEMM - New GEMM pipeline for MX data types (#2059) * Allow selection of mfma_scale instructions * Read B tensor from LDS to VGPR in chunks of 16 in MFMA order * Add constexpr and synchronize return type for `get_exponent_value` * Pass scales by reference and add comments to `mfma_scale_f32_32x32x64` * Add support for microscaling instructions in `XdlopsGemm` * Fix `mfma_scale_f32_16x16x128f8f6f4` wrapper * Remove software implementation of MX GEMM * Make interface of `intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>` consistent with the other scale instruction * Update README * Updated CHANGELOG * Remove unused static methods [ROCm/composable_kernel commit: 7106976a72897f44b05260bd1ae1f70b319a4e75] --- CHANGELOG.md | 1 + example/67_gemm_microscaling/CMakeLists.txt | 9 +- example/67_gemm_microscaling/README.md | 8 +- .../67_gemm_microscaling/gemm_mx_common.hpp | 79 +-- example/67_gemm_microscaling/gemm_mx_fp8.cpp | 98 ++++ .../gemm_mx_fp8_e8m0_scale.cpp | 42 -- .../gemm_mx_fp8_fp16_scale.cpp | 42 -- .../gemm_mx_fp8_fp8_scale.cpp | 42 -- ...blockwise_gemm_mx_pipeline_xdlops_base.hpp | 363 ++++++++++++ ...kwise_gemm_pipeline_xdlops_mx_selector.hpp | 35 +- .../blockwise_gemm_pipeline_xdlops_v1_mx.hpp | 546 +++++++++--------- .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 14 +- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 122 ++-- .../threadwise_tensor_slice_transfer.hpp | 3 +- .../tensor_operation/gpu/warp/xdlops_gemm.hpp | 89 ++- include/ck/utility/amd_xdlops.hpp | 16 +- include/ck/utility/e8m0.hpp | 4 +- include/ck/utility/mxfp_utils.hpp | 4 +- test/mx_mfma_op/mx_mfma_op.hpp | 98 ++-- 19 files changed, 1007 insertions(+), 608 deletions(-) create mode 100644 example/67_gemm_microscaling/gemm_mx_fp8.cpp delete mode 100644 example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp delete mode 100644 example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp delete mode 100644 example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp create mode 100644 include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index e3d7971c71..b9012c0a77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward weight (NGCHW/GKCYX/NGKHW). * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM +* Added GEMM pipeline for microscaling (MX) data types * Added support for FP16 2:4 structured sparsity to universal GEMM. 
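The pipeline added by this patch computes a block-scaled GEMM: every `ScaleBlockSize` (here 32) consecutive elements along K of A (per row) and of B (per column) share one `e8m0` power-of-two scale. Below is a minimal host-side sketch of these semantics, assuming row-major A and A-scales, column-major B and B-scales, and an e8m0 bias of 127; `e8m0_to_float` and `mx_gemm_reference` are illustrative names, not CK API. The per-output cost it implies, `2K` flops for the dot products plus `2K/ScaleBlockSize` for the scale products and scaled accumulation, matches the FLOP comment in the example's timing code further down.

```cpp
// Minimal host-side sketch (not the CK API) of MX GEMM semantics:
// every ScaleBlockSize consecutive K-elements of A (per row) and of
// B (per column) share one e8m0 block scale, i.e. a power-of-two factor.
#include <cstdint>
#include <cmath>
#include <vector>

// Hypothetical helper: decode an e8m0 biased exponent (assumed bias 127) to float.
inline float e8m0_to_float(uint8_t bexp) { return std::ldexp(1.0f, int(bexp) - 127); }

// a: MxK row-major, b: KxN column-major (K contiguous per column);
// scales carry K/ScaleBlockSize entries per row of A / column of B.
void mx_gemm_reference(int M, int N, int K, int ScaleBlockSize,
                       const std::vector<float>& a, const std::vector<uint8_t>& a_scale,
                       const std::vector<float>& b, const std::vector<uint8_t>& b_scale,
                       std::vector<float>& c)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int kb = 0; kb < K / ScaleBlockSize; ++kb)
            {
                float sa = e8m0_to_float(a_scale[m * (K / ScaleBlockSize) + kb]);
                float sb = e8m0_to_float(b_scale[n * (K / ScaleBlockSize) + kb]);
                float partial = 0.f;
                for(int i = 0; i < ScaleBlockSize; ++i)
                {
                    int k = kb * ScaleBlockSize + i;
                    partial += a[m * K + k] * b[n * K + k]; // fp8 values pre-converted to float
                }
                acc += sa * sb * partial; // scale the per-block partial sum
            }
            c[m * N + n] = acc;
        }
}
```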
### Optimized diff --git a/example/67_gemm_microscaling/CMakeLists.txt b/example/67_gemm_microscaling/CMakeLists.txt index 9e95c3e007..93770684df 100644 --- a/example/67_gemm_microscaling/CMakeLists.txt +++ b/example/67_gemm_microscaling/CMakeLists.txt @@ -1,10 +1,5 @@ add_custom_target(example_gemm_mx) -add_example_executable(example_gemm_mx_fp8_e8m0_scale gemm_mx_fp8_e8m0_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_e8m0_scale) +add_example_executable(example_gemm_mx_fp8 gemm_mx_fp8.cpp) +add_example_dependencies(example_gemm_mx example_gemm_mx_fp8) -add_example_executable(example_gemm_mx_fp8_fp8_scale gemm_mx_fp8_fp8_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp8_scale) - -add_example_executable(example_gemm_mx_fp8_fp16_scale gemm_mx_fp8_fp16_scale.cpp) -add_example_dependencies(example_gemm_mx example_gemm_mx_fp8_fp16_scale) diff --git a/example/67_gemm_microscaling/README.md b/example/67_gemm_microscaling/README.md index 713902588d..57b6490eda 100644 --- a/example/67_gemm_microscaling/README.md +++ b/example/67_gemm_microscaling/README.md @@ -10,16 +10,16 @@ Custom verification parameters: # arg4: verbosity (0=no info, 1=verbose info) # arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC # arg11: KBatch -./bin/example_gemm_mx_fp8_e8m0_scale 1 1 0 1 +./bin/example_gemm_mx_fp8 1 1 0 1 ``` Custom tensor shapes: ```bash -./bin/example_gemm_mx_fp8_fp16_scale 1 2 1 0 128 128 64 -1 -1 -1 1 +./bin/example_gemm_mx_fp8 1 2 1 0 128 128 256 -1 -1 -1 1 ``` Default invocation: ```bash -# Implies: ./bin/example_gemm_mx_fp8_fp8_scale 1 2 0 0 -./bin/example_gemm_mx_fp8_fp8_scale +# Implies: ./bin/example_gemm_mx_fp8 1 2 0 0 +./bin/example_gemm_mx_fp8 ``` \ No newline at end of file diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 9a05954c73..32ef975192 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -95,7 +95,7 @@ bool parse_cmd_args(int argc, << std::endl << "arg3: time kernel (0=no, 1=yes)" << std::endl << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M(128x), N(128x), K(64x), StrideA, StrideB, StrideC" << std::endl + << "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl << "arg11: KBatch" << std::endl; return false; } @@ -103,7 +103,8 @@ bool parse_cmd_args(int argc, return true; } -template + ck::index_t ScaleBlockSize> bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { - static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; - static constexpr auto BlkGemmPSched = ck::BlockGemmPipelineScheduler::Intrawave; - static constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v1; - - static constexpr ck::index_t ScaleBlockSize = MXVectorSize; - - static constexpr ck::index_t KPerBlock = 64; - using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< - ALayout, // ALayout - BLayout, // BLayout - CLayout, // CLayout - ADataType, // ADataType - XDataType, // AScaleDataType - BDataType, // BDataType - XDataType, // BScaleDataType - CDataType, // CDataType - AccDataType, // GemmAccDataType - CShuffleDataType, // CShuffleDataType - AElementOp, // AElementwiseOperation - BElementOp, // BElementwiseOperation - CElementOp, // CElementwiseOperation - GemmSpec, // GemmSpec - MXVectorSize, // ScaleBlockSize: Scaling block size - 256, // BlockSize: 
Thread block size - 128, // MPerBlock - 128, // NPerBlock - KPerBlock, // KPerBlock - 16, // AK1 - 16, // BK1 - 32, // MPerXDL - 32, // NPerXDL - 2, // MXdlPerWave - 2, // NXdlPerWave - S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 - S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // ABlockTransferSrcAccessOrder - 2, // ABlockTransferSrcVectorDim - 16, // ABlockTransferSrcScalarPerVector - 16, // ABlockTransferDstScalarPerVector_AK1 - false, // ABlockLdsExtraM - S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 - S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder - S<1, 0, 2>, // BBlockTransferSrcAccessOrder - 2, // BBlockTransferSrcVectorDim - 16, // BBlockTransferSrcScalarPerVector - 16, // BBlockTransferDstScalarPerVector_BK1 - false, // BBlockLdsExtraN - 1, // CShuffleMXdlPerWavePerShuffle - 1, // CShuffleNXdlPerWavePerShuffle - S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock - 8, // CShuffleBlockTransferScalarPerVector_NPerBlock - BlkGemmPSched, // BlkGemmPipeSched - BlkGemmPVer, // BlkGemmPipelineVer - ADataType, // ComputeTypeA - BDataType // ComputeTypeB - >; auto M = problem_size.M; auto N = problem_size.N; @@ -230,8 +175,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, AScaleLayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BScaleLayout{})); + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); Tensor a_m_k_scale(f_host_tensor_descriptor( M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A @@ -428,8 +373,10 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c if(config.time_kernel) { - std::size_t flop = std::size_t(2) * M * N * K + - std::size_t(2) * M * N * K / ScaleBlockSize; // GEMM + A scale + B scale + // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of + // partial sums(K/ScaleBlockSize)] + // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize + std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; std::size_t num_btype = sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N + sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; @@ -445,7 +392,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c return res_verified; } -template , // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 16, // ABlockTransferSrcScalarPerVector + 16, // ABlockTransferDstScalarPerVector_AK1 + false, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 16, // BBlockTransferSrcScalarPerVector + 16, // BBlockTransferDstScalarPerVector_BK1 + false, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 8, // 
CShuffleBlockTransferScalarPerVector_NPerBlock + BlkGemmPSched, // BlkGemmPipeSched + BlkGemmPVer, // BlkGemmPipelineVer + ADataType, // ComputeTypeA + BDataType // ComputeTypeB + >; + +int main(int argc, char* argv[]) +{ + return run_mx_gemm_example(argc, argv) + ? 0 + : -1; +} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp deleted file mode 100644 index 393f4a2ea7..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_e8m0_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::e8m0_bexp_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp deleted file mode 100644 index dd654a8f69..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_fp16_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::half_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 0 - : -1; -} diff --git a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp b/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp deleted file mode 100644 index c42d9783be..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_fp8_fp8_scale.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "gemm_mx_common.hpp" - -using ADataType = ck::f8_t; -using BDataType = ck::f8_t; - -using XDataType = ck::f8_t; - -using CDataType = ck::half_t; -using AccDataType = float; -using CShuffleDataType = CDataType; - -using ALayout = Row; -using BLayout = Col; -using CLayout = Row; - -using AElementOp = PassThrough; // elementwise transformation for A matrix -using BElementOp = PassThrough; // elementwise transformation for B matrix -using CElementOp = PassThrough; // elementwise transformation for C matrix - -constexpr ck::index_t mx_vector_size = 32; // scaling block size - -int main(int argc, char* argv[]) -{ - return run_mx_gemm_example(argc, argv) - ? 
0 - : -1; -} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp new file mode 100644 index 0000000000..ebe075b55d --- /dev/null +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_mx_pipeline_xdlops_base.hpp @@ -0,0 +1,363 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck/utility/common_header.hpp" +#include "ck/utility/blkgemmpipe_scheduler.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" +#include "ck/tensor_description/tensor_adaptor.hpp" + +namespace ck { + +template +struct BlockwiseGemmXdlops_mx_pipeline_base +{ + using ComputeTypeA = ADataType; + using ComputeTypeB = BDataType; + using AccType = float; // for now only support V_MFMA_SCALE_F32 + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + + using ThisThreadBlock = ThisThreadBlock; + + // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs. + static constexpr index_t WaveSize = 64; + + static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); + static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); + static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); + static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2); + + static constexpr auto xdlops_gemm = + XdlopsGemm{}; + + static constexpr index_t AMmaKStride = KPack; + static constexpr index_t BMmaKStride = KPack; + + //> store rows/cols into thread registers in chunks of 16 + //> e.g. 
[k0,...,k15,k64,...,k79] or [k0,...,k15,k32,...,k47] + static constexpr index_t KThreadChunk = 16; + + static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; + static constexpr index_t KRepeat = KPerThread / KPack; + static constexpr index_t KPerInnerLoop = KPack; + + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + + using HotLoopInstList = + ck::BlockwiseGemmXdlops_pipeline_hotloop_inst; + + static_assert(KPerThread % KPack == 0, + "Wrong KPack setting; try increasing KPerThread or decreasing KPack"); + + StaticBufferTupleOfVector + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + + __device__ static auto GetWaveIdx() + { + const index_t thread_id = ThisThreadBlock::GetThreadId(); + + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } + + __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(0, waveId_m, xdlops_a_idx[I1], KThreadChunk * xdlops_a_idx[I0]); + } + + __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(0, waveId_n, xdlops_b_idx[I1], KThreadChunk * xdlops_b_idx[I0]); + } + + template + __device__ static auto + CalculateCThreadOriginDataIndex(Number, Number, Number, Number) + { + const auto wave_idx = GetWaveIdx(); + + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; + + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); + + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); + } + + using Tuple4 = decltype(CalculateAThreadOriginDataIndex()); + + /** + * @brief Constructor for BlockwiseGemmXdlops_mx_pipeline_base. + * + * This constructor initializes the thread copy objects for matrices A and B. + * It also performs several compile-time checks to ensure the correctness of the + * matrix tile descriptors. + * + * @param a_origin The origin data index for matrix A. + * @param b_origin The origin data index for matrix B. + * + * @note The constructor includes static assertions to ensure that: + * - The matrix tile descriptors for A and B are known at compile-time. 
+ * - The number of threads in the thread block matches the product of MWaves, NWaves, and + * WaveSize. + * - The dimensions of the block are divisible by the product of the corresponding XDL and + * repeat dimensions. + */ + __host__ __device__ + BlockwiseGemmXdlops_mx_pipeline_base(Tuple4 a_origin = CalculateAThreadOriginDataIndex(), + Tuple4 b_origin = CalculateBThreadOriginDataIndex()) + : a_thread_copy_(a_origin), b_thread_copy_(b_origin) + { + static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(), + "wrong! Desc should be known at compile-time"); + + static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize, + "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n"); + + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, N, M0, M1, M2)); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, Number{}, I1, I1, M0, M1, M2, N)); + } + + // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl' + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + // XDL output supporting C_xdl = A_xdl * B_xdl + __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2); + } + + __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2() + { + constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 = + make_naive_tensor_descriptor_packed(make_tuple(I1, + Number{}, + Number{}, + Number{}, + 
Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_block_desc_g_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n) + { + const auto M = c_grid_desc_m_n.GetLength(I0); + const auto N = c_grid_desc_m_n.GetLength(I1); + + const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2); + } + + template + __host__ __device__ static constexpr auto + MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n) + { + const auto G = c_grid_desc_g_m_n.GetLength(I0); + const auto M = c_grid_desc_g_m_n.GetLength(I1); + const auto N = c_grid_desc_g_m_n.GetLength(I2); + + const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor( + c_grid_desc_g_m_n, + make_tuple(make_pass_through_transform(G), + make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{})); + + return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2( + c_grid_desc_g_m0_n0_m1_n1_m2_n2); + } + + static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; + static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; + + protected: + // M1, N1 as double buffer index + // Read buffer + Compute buffer + // A[M0, M1, M2, KPack] + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // B[N0, N1, N2, KPack] + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor( + make_tuple(Number{}, I1, Number{}, Number{}), + make_tuple( + Number{}, Number{}, Number{}, I1)); + + // C[M, N, NumRegXdlops] + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); + + using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + A_K1, + A_K1>; + + using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence<0, 1, 2, 3>, + 3, + B_K1, + B_K1>; + + AThreadCopy a_thread_copy_; + BThreadCopy b_thread_copy_; +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp index 24f6afc381..c1433659d6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_selector.hpp @@ -7,6 +7,35 @@ namespace ck { +/** + * @brief Define matrix data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_data_type() +{ + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + +/** + * @brief Define scale data types that have hardware support for MX GEMMs + */ +template +static constexpr bool 
is_scale_mfma_scale_type() +{ + return is_same_v; +} + +/** + * @brief Combination of data types that have hardware support for MX GEMMs + */ +template +static constexpr bool scale_mfma_hw_support() +{ + return is_scale_mfma_data_type() && is_scale_mfma_data_type() && + is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); +} + template constexpr auto BlockGemmMXPipeline_Selector() { + + // Hardware MX GEMM pipeline if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { return BlockwiseGemmXdlops_pipeline_v1_mx - : BlockwiseGemmXdlops_pipeline_base + : BlockwiseGemmXdlops_mx_pipeline_base { - using Base = BlockwiseGemmXdlops_pipeline_base; + + using Base = BlockwiseGemmXdlops_mx_pipeline_base; using Base::I0; using Base::I1; using Base::KRepeat; @@ -134,7 +125,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx How many mx-vectors in each row/col is processed in one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRun = (KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize; + + //> How many scales a thread must read to accommodate one call to xdlops_gemm.Run() + static constexpr auto ScalesPerXdlopsRunPerThread = + ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks; __host__ static constexpr bool BlockHasHotloop(index_t num_loop) { @@ -172,45 +173,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); auto a_scale_thread_buf = make_static_buffer( @@ -276,49 +238,31 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - auto a_scale_thread_buf_group = + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = make_static_buffer( - a_scale_thread_desc_group.GetElementSpaceSize()); - + a_scale_thread_desc_copy.GetElementSpaceSize()); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, - a_scale_thread_desc_group, + a_scale_thread_desc_copy, make_tuple(I0, I0), - a_scale_thread_buf_group); + a_scale_thread_buf_copy); - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto i) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, i)); - a_scale_thread_buf(Number{}) = - a_scale_thread_buf_group[Number{}]; - }); - // go to the next group + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, - make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); - }); // g - - // restore row id and advance to the next scale - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * - xdlops_gemm.mfma_instr.num_groups_per_blk, - 1)); - }); // k0 - - // restore column id and advance to the next set of rows + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); // m0 + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc, @@ -326,15 +270,32 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto n0) { - 
b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, I0), - b_scale_thread_buf); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, 0)); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_buf(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); }); + // restore col id and advance to the next set of scales + // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); @@ -345,8 +306,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx(); - // main body if constexpr(HasMainLoop) { @@ -363,141 +322,166 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. + // t32: |16 --> 31 48 --> 63 | 80 --> 95 112 --> 127 | etc. + // k = 0 k = 1 + + // k indexes mapping to threads for 16x16x128: + // t0 : |0 --> 15 64 --> 79 | 128 --> 143 192 --> 207| etc. + // t16: |16 --> 31 80 --> 95 | 144 --> 159 208 --> 223| etc. + // t32: |32 --> 47 96 --> 111| 160 --> 175 224 --> 239| etc. + // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. 
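The mapping tables here describe how each lane's registers interleave k indices in chunks of `KThreadChunk = 16`. A standalone sketch of the chunk-offset arithmetic used by the `KRepeat` loop that follows, with parameters assumed from the 32x32x64 scale instruction (`num_input_blks = 2`, `K1PerXdlops = 32`, `KPack = 32`); it reproduces the `t0` row of the 32x32x64 table.

```cpp
// Illustrative check (plain C++) of the chunked k-offsets used by the
// KRepeat loop, for mfma_scale_f32_32x32x64f8f6f4 (assumed parameters):
//   KThreadChunk = 16, num_input_blks = 2, K1PerXdlops = 32, KPerXdlops = 64.
// Lane 0 should read k = [0..15] and k = [32..47] for k=0, then the same
// pattern shifted by 64 for k=1 -- matching the comment table above.
#include <cstdio>

int main()
{
    constexpr int KThreadChunk   = 16;
    constexpr int num_input_blks = 2;  // two input blocks -> k-reduction
    constexpr int K1PerXdlops    = 32; // k elements each lane feeds per mfma
    constexpr int KPerXdlops     = 64;
    constexpr int KRepeat        = 2;  // e.g. KPerBlock = 128, KPack = 32

    for(int k = 0; k < KRepeat; ++k)
    {
        int k_step = k * KPerXdlops; // KPack / K1PerXdlops == 1 here
        for(int chunk = 0; chunk < K1PerXdlops / KThreadChunk; ++chunk)
        {
            int first = k_step + chunk * KThreadChunk * num_input_blks;
            std::printf("k=%d chunk=%d: k in [%d..%d]\n",
                        k, chunk, first, first + KThreadChunk - 1);
        }
    }
    // Prints 0..15, 32..47, 64..79, 96..111 for lane 0.
}
```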
+ // k = 0 k = 1 static_for<0, KRepeat, 1>{}([&](auto k) { - constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; - constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run( + a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); + // read block data in chunks to assemble correct thread vectors + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run( + b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, Number{}), + b_thread_buf); + }); }); }); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - c_thread_buf_per_scale.Clear(); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); + constexpr index_t b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); + + static_assert(0 < ScalesPerXdlopsRunPerThread, + "Must have at least one scale per Xdlops per Thread."); + + vector_type + a_scale_thread_vec; + vector_type + b_scale_thread_vec; + + // Pack scale_thread_buf into scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; + }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); // MFMA accumulation - // m = 1:MPerXDL - // n = 1:NPerXDL - // k = 1:KPack - // c(m,n) += a(m,k)*b(k,n) xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - - // one scale per k0 - constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); - - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}( - [&](auto g) { - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}( - [&](auto r) { - constexpr index_t a_scale_offset = - 
a_scale_thread_desc.CalculateOffset( - make_tuple(m0, k0, g, r)); - - constexpr auto reg_offset = - g * xdlops_gemm.mfma_instr.group_size + r; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset( - make_tuple(m0, n0, reg_offset)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert( - b_scale_thread_buf[Number{}]) * - type_convert( - a_scale_thread_buf[Number{}]); - }); - }); + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); + // Prefetch a_scales static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - auto a_scale_thread_buf_group = + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, s)); + auto a_scale_thread_buf_copy = make_static_buffer( - a_scale_thread_desc_group.GetElementSpaceSize()); - + a_scale_thread_desc_copy.GetElementSpaceSize()); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, - a_scale_thread_desc_group, + a_scale_thread_desc_copy, make_tuple(I0, I0), - a_scale_thread_buf_group); + a_scale_thread_buf_copy); - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); - a_scale_thread_buf(Number{}) = - a_scale_thread_buf_group[Number{}]; - }); - // go to the next group + a_scale_thread_buf(Number{}) = + a_scale_thread_buf_copy[Number<0>{}]; a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, - make_multi_index(2 * xdlops_gemm.mfma_instr.group_size, 0)); - }); // g - - // restore row id and advance to the next scale - a_scale_thread_copy.MoveSrcSliceWindow( - a_scale_grid_desc, - make_multi_index(-2 * xdlops_gemm.mfma_instr.group_size * - xdlops_gemm.mfma_instr.num_groups_per_blk, - 1)); - }); // k0 - - // restore column id and advance to the next set of rows + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(MWaves * MPerXDL, -ScalesPerKBlockSize)); - }); // m0 + }); // restore row id and advance to the next set of scales a_scale_thread_copy.MoveSrcSliceWindow( a_scale_grid_desc, make_multi_index(-MPerBlock, ScalesPerKBlockSize)); + // Prefetch b_scales static_for<0, NRepeat, 1>{}([&](auto n0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, I0), - b_scale_thread_buf); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - make_multi_index(NWaves * NPerXDL, 0)); + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + constexpr auto b_scale_offset = + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, s)); + auto b_scale_thread_buf_copy = + make_static_buffer( + b_scale_thread_desc_copy.GetElementSpaceSize()); + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc_copy, + make_tuple(I0, I0), + b_scale_thread_buf_copy); + + b_scale_thread_buf(Number{}) = + b_scale_thread_buf_copy[Number<0>{}]; + b_scale_thread_copy.MoveSrcSliceWindow( + b_scale_grid_desc, + make_multi_index(0, xdlops_gemm.KPerXdlops / ScaleBlockSize)); + }); + }); + b_scale_thread_copy.MoveSrcSliceWindow( + 
b_scale_grid_desc, + make_multi_index(NWaves * NPerXDL, -ScalesPerKBlockSize)); }); + + // restore col id and advance to the next set of scales // NWaves * NPerXDL * NRepeat == NPerBlock b_scale_thread_copy.MoveSrcSliceWindow( b_scale_grid_desc, make_multi_index(-NPerBlock, ScalesPerKBlockSize)); @@ -507,7 +491,6 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto k) { - constexpr auto a_k_step = k * AMmaKStride * KPack / xdlops_gemm.K1PerXdlops; - constexpr auto b_k_step = k * BMmaKStride * KPack / xdlops_gemm.K1PerXdlops; + constexpr auto k_step = + k * xdlops_gemm.KPerXdlops * (KPack / xdlops_gemm.K1PerXdlops); static_for<0, MRepeat, 1>{}([&](auto m0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, - make_tuple(m0, I0, I0, Number{}), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, k, I0), - a_thread_buf); + // read block data in chunks to assemble correct thread + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto a_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k, + make_tuple(m0, I0, I0, Number{}), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, k, Number{}), + a_thread_buf); + }); }); static_for<0, NRepeat, 1>{}([&](auto n0) { - b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, - make_tuple(n0, I0, I0, Number{}), - b_block_buf, - b_thread_desc_, - make_tuple(n0, I0, k, I0), - b_thread_buf); + // read block data in chunks to assemble correct thread + static_for<0, xdlops_gemm.K1PerXdlops / KThreadChunk, 1>{}([&](auto chunk) { + constexpr auto b_k_step_chunk = + k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks; + b_thread_copy_.Run(b_block_desc_n0_n1_n2_k, + make_tuple(n0, I0, I0, Number{}), + b_block_buf, + b_thread_desc_, + make_tuple(n0, I0, k, Number{}), + b_thread_buf); + }); }); }); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - c_thread_buf_per_scale.Clear(); - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto ik) { - a_thread_vec.template AsType()(ik) = + a_thread_vec.template AsType()(ik) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(ik) = + b_thread_vec.template AsType()(ik) = b_thread_buf[Number{}]; }); - using mfma_input_type = - typename vector_type::type; + constexpr index_t a_scale_offset = + a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0)); - xdlops_gemm.template Run<>( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf_per_scale.GetVectorTypeReference(I0)); - - // one scale per k0 constexpr index_t b_scale_offset = - b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0)); + b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0)); - static_for<0, xdlops_gemm.mfma_instr.num_groups_per_blk, 1>{}([&](auto g) { - static_for<0, xdlops_gemm.mfma_instr.group_size, 1>{}([&](auto r) { - constexpr index_t a_scale_offset = - a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, g, r)); + vector_type a_scale_thread_vec; + vector_type b_scale_thread_vec; - constexpr auto reg_offset = - g * xdlops_gemm.mfma_instr.group_size + r; - - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, reg_offset)); - - c_thread_buf(Number{}) += - c_thread_buf_per_scale[Number{}] * - type_convert( - b_scale_thread_buf[Number{}]) * - type_convert( - a_scale_thread_buf[Number{}]); - 
}); + // Pack b_scale_thread_buf into b_scale_thread_vec + static_for<0, ScalesPerXdlopsRunPerThread, 1>{}([&](auto s) { + a_scale_thread_vec.template AsType()(s) = + a_scale_thread_buf[Number{}]; + b_scale_thread_vec.template AsType()(s) = + b_scale_thread_buf[Number{}]; }); + + using mfma_input_type_a = + typename vector_type::type; + using mfma_input_type_b = + typename vector_type::type; + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + // MFMA accumulation + xdlops_gemm.template Run<>( + a_thread_vec.template AsType(), + a_scale_thread_vec.template AsType(), + b_thread_vec.template AsType(), + b_scale_thread_vec.template AsType(), + c_thread_buf.GetVectorTypeReference(Number{})); }); }); }); } } - // TODO: make this field protected when a_scale_thread_copy_ is moved here + // TODO: make this field protected when a_scale_thread_copy_ is moved + // here static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})); + make_tuple(Number{}, Number{}, Number{})); // Is used to copy data from a_scale_grid to a_scale_thread - static constexpr auto a_scale_thread_desc_group = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number<1>{})); + static constexpr auto a_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); - // TODO: make this field protected when b_scale_thread_copy_ is moved here - static constexpr auto b_scale_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + // TODO: make this field protected when b_scale_thread_copy_ is moved + // here + static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); + + // Is used to copy data from b_scale_grid to b_scale_thread_buf + static constexpr auto b_scale_thread_desc_copy = + make_naive_tensor_descriptor_packed(make_tuple(Number<1>{}, Number<1>{})); protected: using Base::a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index 34df9a1d7b..8a370304c6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -694,14 +694,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX || is_same_v || - is_same_v || is_same_v || - is_same_v)&&(is_same_v || - is_same_v || - is_same_v || - is_same_v || - is_same_v), + static_assert(is_scale_mfma_data_type() && is_scale_mfma_data_type(), "Only microscaling formats are supported for ADataType and BDataType"); static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported"); @@ -711,6 +704,11 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX{}; static constexpr auto BK1Number = Number{}; - static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); - static constexpr bool is_single_rate_mfma = - ((is_same::value || is_same::value) && - lcm_AK1_BK1 <= 4) - ? true - : false; + static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number); + static constexpr bool is_single_rate_mfma = false; + static constexpr auto is_scale_mfma = true; + + //> KPack is at least the k_per_blk of selected mfma + // + // Should be a multiple of k_per_blk. 
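A worked instance of the KPack rule stated in the comment above, for the shipped example configuration (`AK1 = BK1 = 16`); the definition itself follows below. The value `k_per_blk = 32` is an assumption consistent with the `f8x32_t` per-lane operands of the f8f6f4 scale intrinsics.

```cpp
// Worked example of the KPack selection, for the shipped example config
// (AK1 = BK1 = 16) with the f8f6f4 scale instructions, whose per-lane
// operands are f8x32_t, i.e. k_per_blk = 32 (an assumption consistent
// with the intrinsic wrappers in amd_xdlops.hpp):
constexpr int AK1         = 16;
constexpr int BK1         = 16;
constexpr int lcm_AK1_BK1 = 16; // lcm(16, 16)
constexpr int k_per_blk   = 32; // f8x32_t per lane per mfma
constexpr int KPack       = (lcm_AK1_BK1 > k_per_blk) ? lcm_AK1_BK1 : k_per_blk; // = 32

static_assert(KPack % k_per_blk == 0, "KPack should be a multiple of k_per_blk");
```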
+ // TODO: Move this to blockwise pipeline base static constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector:: - selected_mfma.k_per_blk); + MfmaSelector::selected_mfma.k_per_blk); using ThisThreadBlock = ThisThreadBlock; @@ -1088,10 +1094,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(KPerBlock % ScaleBlockSize == 0, "KPerBlock should be multiple of ScaleBlockSize"); - static_assert(KPerBlock / ScaleBlockSize == BlockwiseGemmPipe::KRepeat, - "Single call to xdlops_gemm::Run should process exactly ScaleBlockSize " - "elements in k dimension"); - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || @@ -1476,61 +1478,63 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); - static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; - static constexpr auto KPerThread = KPerBlock / K0PerXdlops; - - // NXdlPerWave == NRepeat - // MXdlPerWave == MRepeat - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - - // Initial thread mapping for MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MWaves=NWaves=2 + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] - auto a_thread_offset_m = - MPerXdl * ((get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) / MWaves) + - mfma.selected_mfma.group_size * - ((get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / MPerXdl); - auto a_thread_offset_k = KPerThread * (get_thread_local_1d_id() % MPerXdl) / MPerXdl; + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] - auto b_thread_offset_n = - get_thread_local_1d_id() % NPerXdl + - (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; - auto b_thread_offset_k = KPerThread * (get_thread_local_1d_id() % NPerXdl) / NPerXdl; + // TODO: Document initial thread mapping for more combinations of parameters - auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< - AScaleDataType, - AScaleDataType, - decltype(a_scale_grid_desc_am_ak), // SrcDesc - decltype(BlockwiseGemmPipe::a_scale_thread_desc_group), // DstDesc - Sequence, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 0, // SrcVectorDim - 1, // SrcScalarPerVector - 1, // SrcScalarStrideInVector - true>(a_scale_grid_desc_am_ak, - make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, - a_thread_offset_k / ScaleBlockSize)); + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - auto 
b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< - BScaleDataType, - BScaleDataType, - decltype(b_scale_grid_desc_bn_ak), - decltype(BlockwiseGemmPipe::b_scale_thread_desc), - Sequence<1, BlockwiseGemmPipe::KRepeat>, // SliceLengths - Sequence<0, 1>, // DimAccessOrder - 1, // SrcVectorDim - BlockwiseGemmPipe::KRepeat, // SrcScalarPerVector - 1, - false>(b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, - b_thread_offset_k / ScaleBlockSize)); + static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; + + auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + mfma.selected_mfma.num_threads_per_blk; + + auto a_thread_offset_m = get_thread_local_1d_id() % MPerXdl + waveId_m * MPerXdl; + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + 1, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock + a_thread_offset_m, thread_offset_k)); + + auto b_thread_offset_n = get_thread_local_1d_id() % NPerXdl + waveId_n * NPerXdl; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, // SliceLengths + Sequence<0, 1>, // DimAccessOrder + 1, // SrcVectorDim + 1, // SrcScalarPerVector + 1, + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, thread_offset_k)); blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp index 0310fe37a0..2255505985 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp @@ -211,8 +211,7 @@ struct ThreadwiseTensorSliceTransfer_v1r3 * @tparam SrcVectorDim The dimension along which vectorized access is performed in the source * tensor. * @tparam SrcScalarPerVector The number of scalar elements per vector in the source tensor. - * @tparam SrcScalarStrideInVector The stride of scalar elements within a vector in the source - * tensor. + * @tparam SrcScalarStrideInVector Not used. * @tparam SrcResetCoordinateAfterRun controls whether source coordinate is restored after each Run * or rolled back one step in MoveSrcSliceWindow * @tparam InvalidElementAsNaN Whether to fill invalid elements with NaN (only applicable for diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index a638ca8608..529a1a1729 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -845,15 +845,24 @@ struct mfma_type static constexpr bool is_k_reduction = true; // ??? 
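For the scale thread-copy origins computed in the gridwise kernel above, a small illustration of how `thread_offset_k` and the m/n offsets fall out of the lane id, assuming `num_threads_per_blk = 32` for the 32x32 tile (16 for the 16x16 tile): with two input blocks per wave, lanes 0..31 start at scale column 0 and lanes 32..63 at scale column 1.

```cpp
// Illustrative mapping (plain C++) for the scale thread-copy origin.
#include <cstdio>
#include <initializer_list>

int main()
{
    constexpr int WaveSize            = 64;
    constexpr int num_threads_per_blk = 32; // 32x32 tile; 16 for the 16x16 tile
    constexpr int MPerXdl             = 32;

    for(int tid : {0, 31, 32, 63})
    {
        int thread_offset_k = (tid % WaveSize) / num_threads_per_blk;
        int thread_offset_m = tid % MPerXdl; // + waveId_m * MPerXdl
        std::printf("lane %2d -> scale row %2d, scale col %d\n",
                    tid, thread_offset_m, thread_offset_k);
    }
    // lane  0 -> scale row  0, scale col 0
    // lane 31 -> scale row 31, scale col 0
    // lane 32 -> scale row  0, scale col 1
    // lane 63 -> scale row 31, scale col 1
}
```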
// clang-format on - template + template __device__ void run(const FloatA& a, - const int32_t scale_a, + const ScaleA& scale_a, const FloatB& b, - const int32_t scale_b, + const ScaleB& scale_b, FloatC& reg_c) const { + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + intrin_mfma_scale_f32_32x32x64f8f6f4::Run( - a, scale_a, b, scale_b, reg_c); + a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); } }; @@ -874,15 +883,24 @@ struct mfma_type static constexpr bool is_k_reduction = true; // ??? // clang-format on - template + template __device__ void run(const FloatA& a, - const int32_t scale_a, + const ScaleA& scale_a, const FloatB& b, - const int32_t scale_b, + const ScaleB& scale_b, FloatC& reg_c) const { + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + static_assert(scalar_type::vector_size == 1, "Expect single scale at this point."); + intrin_mfma_scale_f32_16x16x128f8f6f4::Run( - a, scale_a, b, scale_b, reg_c); + a, utils::get_exponent_value(scale_a), b, utils::get_exponent_value(scale_b), reg_c); } }; @@ -890,14 +908,16 @@ template + bool is_single_rate_mfma = false, + bool is_scale_mfma = false> struct MfmaSelector { template + bool is_single_rate_mfma_ = false, + bool is_scale_mfma_ = false> static constexpr auto GetMfma(); template <> @@ -1103,12 +1123,24 @@ struct MfmaSelector return MfmaInstr::mfma_f32_32x32x16f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_32x32x64f8f6f4; + } + template <> constexpr auto GetMfma() { return MfmaInstr::mfma_f32_16x16x32f8f8; } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_scale_f32_16x16x128f8f6f4; + } + template <> constexpr auto GetMfma() { @@ -1145,8 +1177,12 @@ struct MfmaSelector return MfmaInstr::mfma_f32_16x16x32bf8f8; } - static constexpr auto selected_mfma = mfma_type< - GetMfma()>{}; + static constexpr auto selected_mfma = mfma_type()>{}; __host__ __device__ constexpr MfmaSelector() { @@ -1194,7 +1230,8 @@ template + bool TransposeC = false, + bool is_scale_mfma = false> struct XdlopsGemm { static constexpr auto I0 = Number<0>{}; @@ -1225,7 +1262,7 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk"); } // XDL output supporting C = A * B @@ -1368,6 +1405,27 @@ struct XdlopsGemm }); } + template + __device__ void Run(const FloatA& p_a_wave, + const ScaleA& a_scale_thread, + const FloatB& p_b_wave, + const ScaleB& b_scale_thread, + FloatC& p_c_thread) const + { + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + if constexpr(!TransposeC) + { + mfma_instr.template run( + p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread); + } + else + { + mfma_instr.template run( + p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread); + } + }); + } + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } __device__ static auto GetBlkIdx() @@ -1455,7 +1513,8 @@ struct XdlopsGemm KPack <= 4) || (is_same::value && KPack <= 8)) ? 
true - : false > {}; + : false, + is_scale_mfma > {}; static constexpr auto mfma_instr = mfma.selected_mfma; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 0d4611becc..a54a181bf1 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -520,9 +520,9 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> { template __device__ static void Run(const f8x32_t& reg_a, - const int32_t scale_a, + const int32_t& scale_a, const f8x32_t& reg_b, - const int32_t scale_b, + const int32_t& scale_b, FloatC& reg_c) { #if defined(__gfx950__) @@ -538,6 +538,14 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> scale_a, 0, // OPSEL scale_b); + // XXX: Note on the scale_a and scale_b parameters: + // If compiler detects that one or both scales are constant values, it will treat that + // constant as F32 constant. I.e., if scale_a at some point was declared as + // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is + // assigned value `bit_cast(static_cast(a_scale))`. + + // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even + // when OPSEL is set otherwise. #else ignore = reg_a; ignore = scale_a; @@ -556,9 +564,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> { template __device__ static void Run(const f8x32_t& reg_a, - const int32_t scale_a, + const int32_t& scale_a, const f8x32_t& reg_b, - const int32_t scale_b, + const int32_t& scale_b, FloatC& reg_c) { #if defined(__gfx950__) diff --git a/include/ck/utility/e8m0.hpp b/include/ck/utility/e8m0.hpp index a692f533f8..f7d2a2f594 100644 --- a/include/ck/utility/e8m0.hpp +++ b/include/ck/utility/e8m0.hpp @@ -67,10 +67,10 @@ struct e8m0_bexp_t namespace utils { template -__host__ __device__ inline int get_exponent_value(T x); +__host__ __device__ inline constexpr int32_t get_exponent_value(T x); template <> -__host__ __device__ inline int get_exponent_value(e8m0_bexp_t x) +__host__ __device__ inline constexpr int32_t get_exponent_value(e8m0_bexp_t x) { return x.data; } diff --git a/include/ck/utility/mxfp_utils.hpp b/include/ck/utility/mxfp_utils.hpp index f0a86f8750..cf7a3e8713 100644 --- a/include/ck/utility/mxfp_utils.hpp +++ b/include/ck/utility/mxfp_utils.hpp @@ -32,13 +32,13 @@ template __host__ __device__ inline bool is_inf(e8m0_bexp_t const scale, T const data); template -__host__ __device__ inline int get_exponent_value(T x) +__host__ __device__ inline constexpr int32_t get_exponent_value(T x) { x >>= NumericUtils::mant; x &= ((1 << NumericUtils::exp) - 1); - return static_cast(x); + return static_cast(x); } template diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index 1f9091ebc5..d22157c3b3 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -30,48 +30,69 @@ enum class MFMA_F8F6F4 }; -template +template struct mfma_type_selector; -template -struct mfma_type_selector +template <> +struct mfma_type_selector<16, 16> { - __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) + template + __device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); - } - - __device__ void operator()(AFragT const& fragA, - const int32_t scale_a, - BFragT const& fragB, - const int32_t scale_b, - AccumFragT& fragAcc) - { - auto op = mfma_type{}; - op.template run<16, 16, AFragT, BFragT, 
AccumFragT>( - fragA, scale_a, fragB, scale_b, fragAcc); + op.template run<16, 16>(fragA, fragB, fragAcc); } }; -template -struct mfma_type_selector +template <> +struct mfma_type_selector<32, 32> { - __device__ void operator()(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) + template + __device__ static void run(AFragT const& fragA, BFragT const& fragB, AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); + op.template run<32, 32>(fragA, fragB, fragAcc); } +}; - __device__ void operator()(AFragT const& fragA, - const int32_t scale_a, +template +struct mfma_scale_type_selector; + +template <> +struct mfma_scale_type_selector<16, 16> +{ + template + __device__ static void run(AFragT const& fragA, + AScaleFragT const& scale_a, BFragT const& fragB, - const int32_t scale_b, + BScaleFragT const& scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<16, 16>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc); + } +}; + +template <> +struct mfma_scale_type_selector<32, 32> +{ + template + __device__ static void run(AFragT const& fragA, + AScaleFragT const& scale_a, + BFragT const& fragB, + BScaleFragT const& scale_b, AccumFragT& fragAcc) { auto op = mfma_type{}; - op.template run<32, 32, AFragT, BFragT, AccumFragT>( - fragA, scale_a, fragB, scale_b, fragAcc); + op.template run<32, 32>(fragA, scale_a[Number<0>{}], fragB, scale_b[Number<0>{}], fragAcc); } }; @@ -334,8 +355,7 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr, // BLOCK_K / BLOCK_X is a stride in xA matrix auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X); - // obtain 8-bit exponent - fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + fragX = scale_ptr[startOffset]; return load_A_row_major(input_ptr); } @@ -502,7 +522,7 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr, auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X); // obtain 8-bit exponent - fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + fragX = scale_ptr[startOffset]; return load_B_col_major(input_ptr); } @@ -773,7 +793,8 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) // Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N - mfma_type_selector{}(fragA, fragB, fragAcc); + using mfma = mfma_type_selector; + mfma::template run<>(fragA, fragB, fragAcc); for(int i = 0; i < vectorSize(fragC); ++i) { @@ -805,29 +826,34 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; - using ScaleFragT = int32_t; + using AScaleFragT = vector_type::type; + using BScaleFragT = vector_type::type; // Create frags auto fragA = AFragT{}; auto fragB = BFragT{}; auto fragC = CFragT{}; auto fragAcc = AccumFragT{0}; - auto fragXa = ScaleFragT{0}; - auto fragXb = ScaleFragT{0}; + auto fragXa = AScaleFragT{}; + auto fragXb = BScaleFragT{}; // Load the inputs. 
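With this change the test loaders keep scales in `e8m0_bexp_t` form and exponent extraction moves into the mfma wrapper via `utils::get_exponent_value`, which for `e8m0_bexp_t` simply returns the stored byte. A sketch of that handling, assuming the OCP MX convention of a biased 8-bit exponent with bias 127 and no sign or mantissa; `e8m0_bexp_t_sketch` and `decode` are stand-in names, not CK API.

```cpp
// Sketch of the e8m0 handling: get_exponent_value for e8m0_bexp_t returns
// the stored byte, which the scale mfma wrapper forwards to the intrinsic.
#include <cstdint>
#include <cmath>

struct e8m0_bexp_t_sketch // stand-in for ck::e8m0_bexp_t
{
    uint8_t data;
};

constexpr int32_t get_exponent_value(e8m0_bexp_t_sketch x) { return x.data; }

// Decoded scale value represented by an e8m0 byte: 2^(data - 127), assuming bias 127.
inline float decode(e8m0_bexp_t_sketch x) { return std::ldexp(1.0f, int(x.data) - 127); }

static_assert(get_exponent_value(e8m0_bexp_t_sketch{127}) == 127, "127 encodes 1.0f");
```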
// A = col major, BLOCK_M x BLOCK_K - fragA = load_mx_A_row_major( + fragA = load_mx_A_row_major( a, xa, fragXa); // B = col major, BLOCK_K x BLOCK_N - fragB = load_mx_B_col_major( + fragB = load_mx_B_col_major( b, xb, fragXb); // Scaled Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N - mfma_type_selector{}( - fragA, fragXa, fragB, fragXb, fragAcc); + using mfma = mfma_scale_type_selector; + mfma::template run<>(fragA, + fragXa.template AsType(), + fragB, + fragXb.template AsType(), + fragAcc); for(int i = 0; i < vectorSize(fragC); ++i) {
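The notes added to `amd_xdlops.hpp` warn that a scale operand the compiler can prove constant is treated as an F32 constant, so a compile-time `e8m0_bexp_t` scale has to be passed as the bit pattern of its float value, presumably via `bit_cast<int32_t>(static_cast<float>(a_scale))`; they also note the instruction always consumes byte 0 of the scale regardless of OPSEL. A hedged sketch of that workaround follows; `bit_cast_sketch` and `scale_operand_for_constant` are illustrative stand-ins, not the CK helpers.

```cpp
// Hedged sketch of the scale-constant caveat noted in amd_xdlops.hpp:
// a compile-time e8m0 scale must be passed as the bit pattern of its
// float value rather than as its raw exponent byte.
#include <cstdint>
#include <cstring>

template <typename To, typename From>
To bit_cast_sketch(const From& f) // stand-in for ck::bit_cast
{
    static_assert(sizeof(To) == sizeof(From), "sizes must match");
    To t;
    std::memcpy(&t, &f, sizeof(To));
    return t;
}

int32_t scale_operand_for_constant(float a_scale /* e8m0 scale already widened to float */)
{
    // e.g. e8m0_bexp_t a_scale{1.0f};
    // -> pass bit_cast<int32_t>(static_cast<float>(a_scale)) to the intrinsic
    return bit_cast_sketch<int32_t>(a_scale);
}
```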