diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 4f20487b9b..8c0b950941 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -780,7 +780,6 @@ struct mfma_type } }; -// TODO: fix mfma...f8f6f4 instructions template <> struct mfma_type { @@ -847,9 +846,14 @@ struct mfma_type // clang-format on template - __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + __device__ void run(const FloatA& a, + const int32_t scale_a, + const FloatB& b, + const int32_t scale_b, + FloatC& reg_c) const { - intrin_mfma_scale_f32_32x32x64f8f6f4::Run(a, b, reg_c); + intrin_mfma_scale_f32_32x32x64f8f6f4::Run( + a, scale_a, b, scale_b, reg_c); } }; @@ -871,9 +875,14 @@ struct mfma_type // clang-format on template - __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + __device__ void run(const FloatA& a, + const int32_t scale_a, + const FloatB& b, + const int32_t scale_b, + FloatC& reg_c) const { - intrin_mfma_scale_f32_16x16x128f8f6f4::Run(a, b, reg_c); + intrin_mfma_scale_f32_16x16x128f8f6f4::Run( + a, scale_a, b, scale_b, reg_c); } }; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index b125e3adf6..010b7aabd3 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -533,9 +533,9 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> reg_c.template AsType()[Number<0>{}], 0, // cbsz 0, // blgp - 0, // { OPSEL_HI[0], OPSEL[0] }? + 0, // OPSEL scale_a, - 0, // { OPSEL_HI[1], OPSEL[1] }? + 0, // OPSEL scale_b); #else ignore = reg_a; @@ -569,9 +569,9 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> reg_c.template AsType()[Number<0>{}], 0, // cbsz 0, // blgp - 0, // { OPSEL_HI[0], OPSEL[0] }? + 0, // OPSEL scale_a, - 0, // { OPSEL_HI[1], OPSEL[1] }? + 0, // OPSEL scale_b); #else ignore = reg_a; diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index cc612794f4..f65e89bb82 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -30,11 +30,11 @@ bool run_mfma_test(ck::index_t init) constexpr auto BLOCK_N = mfma_instr.n_per_blk; constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; - const auto mx_mfma_kernel = ck::matmul; + const auto mfma_kernel = ck::matmul; bool pass = true; - pass = ck::mfma_test::TestMFMA{}(mx_mfma_kernel, init); + BLOCK_K>{}(mfma_kernel, init); return pass; } TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 0; + auto AB_init = 4; auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP8MFMA32x32x64) { - auto AB_init = 0; + auto AB_init = 4; auto pass = run_mfma_test(AB_init); EXPECT_TRUE(pass); } + +/** + * @brief Run the test for the given MX MFMA instruction + * + * @param init - selects initialization algorithm for A and B tensors + */ +template +bool run_mxmfma_test(ck::index_t init) +{ + static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 || + mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64, + "Only SCALE_F32_16x16x128 and SCALE_F32_32x32x64 are supported"); + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + using AccType = float; // only MFMA_F32 instructions supported + using 
ScaleType = ck::e8m0_bexp_t; // biased exponent type + + ck::mfma_type(mfma)> mfma_instr; + constexpr auto BLOCK_M = mfma_instr.m_per_blk; + constexpr auto BLOCK_N = mfma_instr.n_per_blk; + constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; + constexpr auto BLOCK_X = 32; // scaling vector size + + const auto mx_mfma_kernel = + ck::matmul; + + bool pass = true; + + pass = ck::mxmfma_test::TestMXMFMA{}(mx_mfma_kernel, init); + + return pass; +} + +TEST(MXMFMA, MXFP8MFMA16x16x128) +{ + auto AB_init = 7; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP8MFMA32x32x64) +{ + auto AB_init = 7; + auto pass = run_mxmfma_test(AB_init); + EXPECT_TRUE(pass); +} diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index e96e1b0b29..1f9091ebc5 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -1,3 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ #pragma once #include "ck/ck.hpp" @@ -7,7 +10,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp" #include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" #include "ck/library/utility/check_err.hpp" namespace ck { @@ -18,7 +21,13 @@ enum class MFMA_F8F6F4 F32_16x16x128 = static_cast(MfmaInstr::mfma_f32_16x16x128f8f6f4), // V_MFMA_F32_16X16X128_F8F6F4 F32_32x32x64 = - static_cast(MfmaInstr::mfma_f32_32x32x64f8f6f4) // V_MFMA_F32_32X32X64_F8F6F4 + static_cast(MfmaInstr::mfma_f32_32x32x64f8f6f4), // V_MFMA_F32_32X32X64_F8F6F4 + + SCALE_F32_16x16x128 = static_cast( + MfmaInstr::mfma_scale_f32_16x16x128f8f6f4), // V_MFMA_SCALE_F32_16X16X128_F8F6F4 + SCALE_F32_32x32x64 = static_cast( + MfmaInstr::mfma_scale_f32_32x32x64f8f6f4) // V_MFMA_SCALE_F32_32X32X64_F8F6F4 + }; template @@ -32,6 +41,17 @@ struct mfma_type_selector auto op = mfma_type{}; op.template run<16, 16, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); } + + __device__ void operator()(AFragT const& fragA, + const int32_t scale_a, + BFragT const& fragB, + const int32_t scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<16, 16, AFragT, BFragT, AccumFragT>( + fragA, scale_a, fragB, scale_b, fragAcc); + } }; template @@ -42,6 +62,17 @@ struct mfma_type_selector auto op = mfma_type{}; op.template run<32, 32, AFragT, BFragT, AccumFragT>(fragA, fragB, fragAcc); } + + __device__ void operator()(AFragT const& fragA, + const int32_t scale_a, + BFragT const& fragB, + const int32_t scale_b, + AccumFragT& fragAcc) + { + auto op = mfma_type{}; + op.template run<32, 32, AFragT, BFragT, AccumFragT>( + fragA, scale_a, fragB, scale_b, fragAcc); + } }; template @@ -52,151 +83,428 @@ static constexpr int32_t vectorSize(const VecT&) // Define a load function for input A blocks: // 
Size: (BLOCK_M x BLOCK_K) -// ASSUMPTION: -// - We want contiguous BLOCK_M sized column neighbors in register. -// - Data is in col_major format -// This means: -// - From A we will load K columns of size BLOCK_M to satisfy our input data +// - Data is in column major format +// - Rows are loaded in contiguous chunks that map to corresponding microscales +// - Each row is loaded in chunks of size 16 and each thread loads 32 elements template __device__ AFragT load_A_col_major(AType const* input_ptr) { // clang-format off // Register Mapping for 16x128: || Register Mapping for 32x64: - // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | - // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || M | 0 ... 31 | 0 ... 31 | - // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 63 | Vector - // Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element - // Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0] - // Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1] - // Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2] - // Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3] - // Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4] - // Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5] - // Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6] - // Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7] - // Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8] - // Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9] - // Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10] - // Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | 
K11 | K43 | v[11] - // Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12] - // Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13] - // Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14] - // Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15] - // Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16] - // Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17] - // Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18] - // Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19] - // Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20] - // Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 5 [8:15] | K21 | K53 | v[21] - // Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22] - // Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23] - // Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24] - // Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25] - // Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26] - // Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27] - // Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28] - // Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29] - // Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30] - // Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31] + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 
47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | 
v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | // clang-format on - // Here we want to load a BLOCK_M x BLOCK_K block of data. - static constexpr uint32_t VW = vectorSize(AFragT{}); - using ARawT = typename scalar_type::type; - using AScalarFragT = vector_type::type; + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from rows of A in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. 
- auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row - (threadIdx.x / BLOCK_M) * VW); // Col - auto stepCoord2D = std::make_pair(0u, 1u); + auto startCoord2D = + std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15} + (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48} + + auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows + auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row // Flatten to 1D col_major offsets. auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; // BLOCK_M is a stride in A matrix - auto startOffset = col_major(startCoord2D, BLOCK_M); - auto kOffset = col_major(stepCoord2D, BLOCK_M); + auto startOffset = col_major(startCoord2D, BLOCK_M); + auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M); + auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M); - // kOffset == BLOCK_M - // This means every BLOCK_M element is loaded into output vector - auto fragA = AScalarFragT{}; -#pragma unroll VW - for(uint32_t i = 0; i < VW; i++) + using ARawT = typename scalar_type::type; + using AScalarFragT = vector_type::type; + + AScalarFragT fragA{}; + +#pragma unroll + for(int chunk = 0; chunk < 2; chunk++) { - fragA[i] = bit_cast(input_ptr[startOffset + i * kOffset]); +#pragma unroll + for(uint32_t i = 0; i < chunk_size; i++) + { + fragA[chunk * chunk_size + i] = + bit_cast(input_ptr[startOffset + chunk * kMajorOffset + i * kMinorOffset]); + } } return fragA; } +// Define a load function for input A blocks: +// Size: (BLOCK_M x BLOCK_K) +// - Data is in row major format +// - Rows are loaded in contiguous chunks that map to corresponding microscales +// - Each row is loaded in chunks of size 16 and each thread loads 32 elements +template +__device__ AFragT load_A_row_major(AType const* input_ptr) +{ + // clang-format off + // Register Mapping for 16x128: || Register Mapping for 32x64: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || 
Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || 
Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] | K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + // clang-format on + + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from rows of A in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_M; + + // To start the loading process, let's visualize in 2D coords. + // Each thread will load 32 elements. + // We need to know where they start, and where the next elements are. 
+ auto startCoord2D = + std::make_pair(threadIdx.x % BLOCK_M, // Row {0-31} | {0-15} + (threadIdx.x / BLOCK_M) * chunk_size); // Col {0, 16} | {0, 16, 32, 48} + + // auto minorStepCoord2D = std::make_pair(0u, 1u); // read rows + auto majorStepCoord2D = std::make_pair(0, chunk_offset); // read a chunk from a row + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + // BLOCK_K is a stride in A matrix + auto startOffset = row_major(startCoord2D, BLOCK_K); + // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K); + auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K); + + using ARawT = typename scalar_type::type; + using AScalarFragT = vector_type::type; + + union + { + AFragT frag; + AScalarFragT chunks[2]; + } fragA{}; + + auto* fragPtr = reinterpret_cast(input_ptr + startOffset); + fragA.chunks[0] = *fragPtr; + fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); + fragA.chunks[1] = *fragPtr; + + return fragA.frag; +} + +// Define a load function for scaled A blocks: +// Size: (BLOCK_M x BLOCK_K) +// ASSUMPTION: +// - The scale inputs distributed across 64 lanes. +template +__device__ AFragT load_mx_A_row_major(AType const* input_ptr, + ScaleType const* scale_ptr, + ScaleFragT& fragX) +{ + // clang-format off + // Register Mapping for 16x128: || Register Mapping for 32x64: + // Size | BLOCK_M | BLOCK_M | | BLOCK_M | BLOCK_M | | || Size | BLOCK_M | BLOCK_M | | | + // M | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || M | 0 ... 31 | 0 ... 31 | Vector | | + // Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| Scale | + // Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------| + // Reg 0 [0:7] | K0 | K16 | x(M,0) | K32 | K48 | x(M,1) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(M,0) | + // Reg 0 [8:15] | K1 | K17 | x(M,0) | K33 | K49 | x(M,1) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(M,0) | + // Reg 0 [16:23] | K2 | K18 | x(M,0) | K34 | K50 | x(M,1) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(M,0) | + // Reg 0 [24:31] | K3 | K19 | x(M,0) | K35 | K51 | x(M,1) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(M,0) | + // Reg 1 [0:7] | K4 | K20 | x(M,0) | K36 | K52 | x(M,1) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(M,0) | + // Reg 1 [8:15] | K5 | K21 | x(M,0) | K37 | K53 | x(M,1) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(M,0) | + // Reg 1 [16:23] | K6 | K22 | x(M,0) | K38 | K54 | x(M,1) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(M,0) | + // Reg 1 [24:31] | K7 | K23 | x(M,0) | K39 | K55 | x(M,1) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(M,0) | + // Reg 2 [0:7] | K8 | K24 | x(M,0) | K40 | K56 | x(M,1) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(M,0) | + // Reg 2 [8:15] | K9 | K25 | x(M,0) | K41 | K57 | x(M,1) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(M,0) | + // Reg 2 [16:23] | K10 | K26 | x(M,0) | K42 | K58 | x(M,1) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(M,0) | + // Reg 2 [24:31] | K11 | K27 | x(M,0) | K43 | K59 | x(M,1) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(M,0) | + // Reg 3 [0:7] | K12 | K28 | x(M,0) | K44 | K60 | x(M,1) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(M,0) | + // Reg 3 [8:15] | K13 | K29 | x(M,0) | K45 | K61 | x(M,1) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(M,0) | + // Reg 3 [16:23] | K14 | K30 | x(M,0) | K46 | K62 | x(M,1) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(M,0) | + // Reg 3 [24:31] | K15 | K31 | x(M,0) | K47 | K63 | x(M,1) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | 
x(M,0) | + // Reg 4 [0:7] | K64 | K80 | x(M,2) | K96 | K112 | x(M,3) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(M,1) | + // Reg 4 [8:15] | K65 | K81 | x(M,2) | K97 | K113 | x(M,3) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(M,1) | + // Reg 4 [16:23] | K66 | K82 | x(M,2) | K98 | K114 | x(M,3) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(M,1) | + // Reg 4 [24:31] | K67 | K83 | x(M,2) | K99 | K115 | x(M,3) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(M,1) | + // Reg 5 [0:7] | K68 | K84 | x(M,2) | K100 | K116 | x(M,3) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(M,1) | + // Reg 5 [8:15] | K69 | K85 | x(M,2) | K101 | K117 | x(M,3) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(M,1) | + // Reg 5 [16:23] | K70 | K86 | x(M,2) | K102 | K118 | x(M,3) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(M,1) | + // Reg 5 [24:31] | K71 | K87 | x(M,2) | K103 | K119 | x(M,3) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(M,1) | + // Reg 6 [0:7] | K72 | K88 | x(M,2) | K104 | K120 | x(M,3) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(M,1) | + // Reg 6 [8:15] | K73 | K89 | x(M,2) | K105 | K121 | x(M,3) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(M,1) | + // Reg 6 [16:23] | K74 | K90 | x(M,2) | K106 | K122 | x(M,3) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(M,1) | + // Reg 6 [24:31] | K75 | K91 | x(M,2) | K107 | K123 | x(M,3) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(M,1) | + // Reg 7 [0:7] | K76 | K92 | x(M,2) | K108 | K124 | x(M,3) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(M,1) | + // Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) | + // Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) | + // Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) | + // clang-format on + static constexpr uint32_t VW = vectorSize(AFragT{}); + static_assert(VW == BLOCK_X, 
"Fragment size must be equal to BLOCK_X"); + + // To start the loading process, let's visualize in 2D coords. + // Each thread will load 1 element + // We need to know where they start + auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row + (threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + // BLOCK_K / BLOCK_X is a stride in xA matrix + auto startOffset = row_major(startCoord2D, BLOCK_K / BLOCK_X); + + // obtain 8-bit exponent + fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + + return load_A_row_major(input_ptr); +} + // Define a load function for input B blocks: // Size: (BLOCK_K x BLOCK_N) -// ASSUMPTION: -// - We want contiguous BLOCK_N sized row neighbors in register. -// - Data is in row_major format -// This means: -// - From B we will load K rows of size BLOCK_N to satisfy our input data +// - Data is in col major format +// - Cols are loaded in contiguous chunks that map to corresponding microscales +// - Each col is loaded in chunks of size 16 and each thread loads 32 elements template __device__ BFragT load_B_col_major(BType const* input_ptr) { // clang-format off // Register Mapping for 128x16: || Register Mapping for 64x32: - // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | - // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | || N | 0 ... 31 | 0 ... 31 | - // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector || Thread Id | 0 ... 31 | 32 ... 
63 | Vector - // Register Element ------------ ------------- ------------ ------------- Element || Register Element ------------ ------------- Element - // Reg 0 [0:7] | K0 | K32 | K64 | K96 | v[0] || Reg 0 [0:7] | K0 | K32 | v[0] - // Reg 0 [8:15] | K1 | K33 | K65 | K97 | v[1] || Reg 0 [8:15] | K1 | K33 | v[1] - // Reg 0 [16:23] | K2 | K34 | K66 | K98 | v[2] || Reg 0 [16:23] | K2 | K34 | v[2] - // Reg 0 [24:31] | K3 | K35 | K67 | K99 | v[3] || Reg 0 [24:31] | K3 | K35 | v[3] - // Reg 1 [0:7] | K4 | K36 | K68 | K100 | v[4] || Reg 1 [0:7] | K4 | K36 | v[4] - // Reg 1 [8:15] | K5 | K37 | K69 | K101 | v[5] || Reg 1 [8:15] | K5 | K37 | v[5] - // Reg 1 [16:23] | K6 | K38 | K70 | K102 | v[6] || Reg 1 [16:23] | K6 | K38 | v[6] - // Reg 1 [24:31] | K7 | K39 | K71 | K103 | v[7] || Reg 1 [24:31] | K7 | K39 | v[7] - // Reg 2 [0:7] | K8 | K40 | K72 | K104 | v[8] || Reg 2 [0:7] | K8 | K40 | v[8] - // Reg 2 [8:15] | K9 | K41 | K73 | K105 | v[9] || Reg 2 [8:15] | K9 | K41 | v[9] - // Reg 2 [16:23] | K10 | K42 | K74 | K106 | v[10] || Reg 2 [16:23] | K10 | K42 | v[10] - // Reg 2 [24:31] | K11 | K43 | K75 | K107 | v[11] || Reg 2 [24:31] | K11 | K43 | v[11] - // Reg 3 [0:7] | K12 | K44 | K76 | K108 | v[12] || Reg 3 [0:7] | K12 | K44 | v[12] - // Reg 3 [8:15] | K13 | K45 | K77 | K109 | v[13] || Reg 3 [8:15] | K13 | K45 | v[13] - // Reg 3 [16:23] | K14 | K46 | K78 | K110 | v[14] || Reg 3 [16:23] | K14 | K46 | v[14] - // Reg 3 [24:31] | K15 | K47 | K79 | K111 | v[15] || Reg 3 [24:31] | K15 | K47 | v[15] - // Reg 4 [0:7] | K16 | K48 | K80 | K112 | v[16] || Reg 4 [0:7] | K16 | K48 | v[16] - // Reg 4 [8:15] | K17 | K49 | K81 | K113 | v[17] || Reg 4 [8:15] | K17 | K49 | v[17] - // Reg 4 [16:23] | K18 | K50 | K82 | K114 | v[18] || Reg 4 [16:23] | K18 | K50 | v[18] - // Reg 4 [24:31] | K19 | K51 | K83 | K115 | v[19] || Reg 4 [24:31] | K19 | K51 | v[19] - // Reg 5 [0:7] | K20 | K52 | K84 | K116 | v[20] || Reg 5 [0:7] | K20 | K52 | v[20] - // Reg 5 [8:15] | K21 | K53 | K85 | K117 | v[21] || Reg 
5 [8:15] | K21 | K53 | v[21] - // Reg 5 [16:23] | K22 | K54 | K86 | K118 | v[22] || Reg 5 [16:23] | K22 | K54 | v[22] - // Reg 5 [24:31] | K23 | K55 | K87 | K119 | v[23] || Reg 5 [24:31] | K23 | K55 | v[23] - // Reg 6 [0:7] | K24 | K56 | K88 | K120 | v[24] || Reg 6 [0:7] | K24 | K56 | v[24] - // Reg 6 [8:15] | K25 | K57 | K89 | K121 | v[25] || Reg 6 [8:15] | K25 | K57 | v[25] - // Reg 6 [16:23] | K26 | K58 | K90 | K122 | v[26] || Reg 6 [16:23] | K26 | K58 | v[26] - // Reg 6 [24:31] | K27 | K59 | K91 | K123 | v[27] || Reg 6 [24:31] | K27 | K59 | v[27] - // Reg 7 [0:7] | K28 | K60 | K92 | K124 | v[28] || Reg 7 [0:7] | K28 | K60 | v[28] - // Reg 7 [8:15] | K29 | K61 | K93 | K125 | v[29] || Reg 7 [8:15] | K29 | K61 | v[29] - // Reg 7 [16:23] | K30 | K62 | K94 | K126 | v[30] || Reg 7 [16:23] | K30 | K62 | v[30] - // Reg 7 [24:31] | K31 | K63 | K95 | K127 | v[31] || Reg 7 [24:31] | K31 | K63 | v[31] + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0 | K16 | K32 | K48 | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | + // Reg 0 [8:15] | K1 | K17 | K33 | K49 | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | + // Reg 0 [16:23] | K2 | K18 | K34 | K50 | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | + // Reg 0 [24:31] | K3 | K19 | K35 | K51 | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | + // Reg 1 [0:7] | K4 | K20 | K36 | K52 | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | + // Reg 1 [8:15] | K5 | K21 | K37 | K53 | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | + // Reg 1 [16:23] | K6 | K22 | K38 | K54 | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | + // Reg 1 [24:31] | K7 | K23 | K39 | K55 | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | + // Reg 2 [0:7] | K8 | K24 | K40 | K56 | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | + // Reg 2 [8:15] | K9 | K25 | K41 | K57 | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | + // Reg 2 [16:23] | K10 | K26 | K42 | K58 | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | + // Reg 2 [24:31] | K11 | K27 | K43 | K59 | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | + // Reg 3 [0:7] | K12 | K28 | K44 | K60 | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | + // Reg 3 [8:15] | K13 | K29 | K45 | K61 | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | + // Reg 3 [16:23] | K14 | K30 | K46 | K62 | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | + // Reg 3 [24:31] | K15 | K31 | K47 | K63 | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | + // Reg 4 [0:7] | K64 | K80 | K96 | K112 | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | + // Reg 4 [8:15] | K65 | K81 | K97 | K113 | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | + // Reg 4 [16:23] | K66 | K82 | K98 | K114 | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | + // Reg 4 [24:31] | K67 | K83 | K99 | K115 | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | + // Reg 5 [0:7] | K68 | K84 | K100 | K116 | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | + // Reg 5 [8:15] 
| K69 | K85 | K101 | K117 | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | + // Reg 5 [16:23] | K70 | K86 | K102 | K118 | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | + // Reg 5 [24:31] | K71 | K87 | K103 | K119 | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | + // Reg 6 [0:7] | K72 | K88 | K104 | K120 | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | + // Reg 6 [8:15] | K73 | K89 | K105 | K121 | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | + // Reg 6 [16:23] | K74 | K90 | K106 | K122 | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | + // Reg 6 [24:31] | K75 | K91 | K107 | K123 | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | + // Reg 7 [0:7] | K76 | K92 | K108 | K124 | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | + // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | + // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | + // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | // clang-format on - // Here we want to load a BLOCK_K x BLOCK_N block of data. - static constexpr uint32_t VW = vectorSize(BFragT{}); + static constexpr int32_t WAVE_SIZE = 64; + + // Here we want to load from cols of B in chunks of 16 elements each. + static constexpr uint32_t chunk_size = 16; + + // each chunk is separated by an offset + static constexpr uint32_t chunk_offset = chunk_size * WAVE_SIZE / BLOCK_N; // 32 or 64 // To start the loading process, let's visualize in 2D coords. // Each thread will load 32 elements. // We need to know where they start, and where the next elements are. - auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW, // Row - threadIdx.x % BLOCK_N); // Col + auto startCoord2D = + std::make_pair((threadIdx.x / BLOCK_N) * chunk_size, // Row {0, 16} | {0, 16, 32, 48} + threadIdx.x % BLOCK_N); // Col {0-31} | {0-15} // Flatten to 1D col_major offsets. 
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; - auto startOffset = col_major(startCoord2D, BLOCK_K); + // auto minorStepCoord2D = std::make_pair(1u, 0u); // read cols + auto majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col - auto const* fragPtr = reinterpret_cast(input_ptr + startOffset); - return *fragPtr; + // BLOCK_K is a stride in B matrix + auto startOffset = col_major(startCoord2D, BLOCK_K); + // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K); + auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K); + + using BRawT = typename scalar_type::type; + using BScalarFragT = vector_type::type; + + union + { + BFragT frag; + BScalarFragT chunks[2]; + } fragB{}; + + auto* fragPtr = reinterpret_cast(input_ptr + startOffset); + fragB.chunks[0] = *fragPtr; + fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); + fragB.chunks[1] = *fragPtr; + + return fragB.frag; +} + +// Define a load function for scaled B blocks: +// Size: (BLOCK_K x BLOCK_N) +// ASSUMPTION: +// - The scale inputs distributed across 64 lanes. +template +__device__ BFragT load_mx_B_col_major(BType const* input_ptr, + ScaleType const* scale_ptr, + ScaleFragT& fragX) + +{ + // clang-format off + // Register Mapping for 128x16: || Register Mapping for 64x32: + // Size | BLOCK_N | BLOCK_N | | BLOCK_N | BLOCK_N | | || Size | BLOCK_N | BLOCK_N | | | + // N | 0 ... 15 | 0 ... 15 | | 0 ... 15 | 0 ... 15 | | Vector || N | 0 ... 31 | 0 ... 31 | Vector | | + // Thread Id | 0 ... 15 | 16 ... 31 | Scale | 32 ... 47 | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| Scale | + // Register Element ------------ ------------- ----------|------------ ------------- ----------|-----------|| Register Element |------------|-------------|--------|----------| + // Reg 0 [0:7] | K0 | K16 | x(0,N) | K32 | K48 | x(1,N) | v[0] || Reg 0 [0:7] | K0 | K16 | v[0] | x(0,N) | + // Reg 0 [8:15] | K1 | K17 | x(0,N) | K33 | K49 | x(1,N) | v[1] || Reg 0 [8:15] | K1 | K17 | v[1] | x(0,N) | + // Reg 0 [16:23] | K2 | K18 | x(0,N) | K34 | K50 | x(1,N) | v[2] || Reg 0 [16:23] | K2 | K18 | v[2] | x(0,N) | + // Reg 0 [24:31] | K3 | K19 | x(0,N) | K35 | K51 | x(1,N) | v[3] || Reg 0 [24:31] | K3 | K19 | v[3] | x(0,N) | + // Reg 1 [0:7] | K4 | K20 | x(0,N) | K36 | K52 | x(1,N) | v[4] || Reg 1 [0:7] | K4 | K20 | v[4] | x(0,N) | + // Reg 1 [8:15] | K5 | K21 | x(0,N) | K37 | K53 | x(1,N) | v[5] || Reg 1 [8:15] | K5 | K21 | v[5] | x(0,N) | + // Reg 1 [16:23] | K6 | K22 | x(0,N) | K38 | K54 | x(1,N) | v[6] || Reg 1 [16:23] | K6 | K22 | v[6] | x(0,N) | + // Reg 1 [24:31] | K7 | K23 | x(0,N) | K39 | K55 | x(1,N) | v[7] || Reg 1 [24:31] | K7 | K23 | v[7] | x(0,N) | + // Reg 2 [0:7] | K8 | K24 | x(0,N) | K40 | K56 | x(1,N) | v[8] || Reg 2 [0:7] | K8 | K24 | v[8] | x(0,N) | + // Reg 2 [8:15] | K9 | K25 | x(0,N) | K41 | K57 | x(1,N) | v[9] || Reg 2 [8:15] | K9 | K25 | v[9] | x(0,N) | + // Reg 2 [16:23] | K10 | K26 | x(0,N) | K42 | K58 | x(1,N) | v[10] || Reg 2 [16:23] | K10 | K26 | v[10] | x(0,N) | + // Reg 2 [24:31] | K11 | K27 | x(0,N) | K43 | K59 | x(1,N) | v[11] || Reg 2 [24:31] | K11 | K27 | v[11] | x(0,N) | + // Reg 3 [0:7] | K12 | K28 | x(0,N) | K44 | K60 | x(1,N) | v[12] || Reg 3 [0:7] | K12 | K28 | v[12] | x(0,N) | + // Reg 3 [8:15] | K13 | K29 | x(0,N) | K45 | K61 | x(1,N) | v[13] || Reg 3 [8:15] | K13 | K29 | v[13] | x(0,N) | + // Reg 3 [16:23] | K14 | K30 | x(0,N) | K46 | K62 | x(1,N) | v[14] || Reg 3 [16:23] | K14 | K30 | v[14] | x(0,N) | + // Reg 3 [24:31] | K15 | K31 | x(0,N) | K47 | K63 | x(1,N) | v[15] || Reg 3 [24:31] | K15 | K31 | v[15] | 
x(0,N) | + // Reg 4 [0:7] | K64 | K80 | x(2,N) | K96 | K112 | x(3,N) | v[16] || Reg 4 [0:7] | K32 | K48 | v[16] | x(1,N) | + // Reg 4 [8:15] | K65 | K81 | x(2,N) | K97 | K113 | x(3,N) | v[17] || Reg 4 [8:15] | K33 | K49 | v[17] | x(1,N) | + // Reg 4 [16:23] | K66 | K82 | x(2,N) | K98 | K114 | x(3,N) | v[18] || Reg 4 [16:23] | K34 | K50 | v[18] | x(1,N) | + // Reg 4 [24:31] | K67 | K83 | x(2,N) | K99 | K115 | x(3,N) | v[19] || Reg 4 [24:31] | K35 | K51 | v[19] | x(1,N) | + // Reg 5 [0:7] | K68 | K84 | x(2,N) | K100 | K116 | x(3,N) | v[20] || Reg 5 [0:7] | K36 | K52 | v[20] | x(1,N) | + // Reg 5 [8:15] | K69 | K85 | x(2,N) | K101 | K117 | x(3,N) | v[21] || Reg 5 [8:15] | K37 | K53 | v[21] | x(1,N) | + // Reg 5 [16:23] | K70 | K86 | x(2,N) | K102 | K118 | x(3,N) | v[22] || Reg 5 [16:23] | K38 | K54 | v[22] | x(1,N) | + // Reg 5 [24:31] | K71 | K87 | x(2,N) | K103 | K119 | x(3,N) | v[23] || Reg 5 [24:31] | K39 | K55 | v[23] | x(1,N) | + // Reg 6 [0:7] | K72 | K88 | x(2,N) | K104 | K120 | x(3,N) | v[24] || Reg 6 [0:7] | K40 | K56 | v[24] | x(1,N) | + // Reg 6 [8:15] | K73 | K89 | x(2,N) | K105 | K121 | x(3,N) | v[25] || Reg 6 [8:15] | K41 | K57 | v[25] | x(1,N) | + // Reg 6 [16:23] | K74 | K90 | x(2,N) | K106 | K122 | x(3,N) | v[26] || Reg 6 [16:23] | K42 | K58 | v[26] | x(1,N) | + // Reg 6 [24:31] | K75 | K91 | x(2,N) | K107 | K123 | x(3,N) | v[27] || Reg 6 [24:31] | K43 | K59 | v[27] | x(1,N) | + // Reg 7 [0:7] | K76 | K92 | x(2,N) | K108 | K124 | x(3,N) | v[28] || Reg 7 [0:7] | K44 | K60 | v[28] | x(1,N) | + // Reg 7 [8:15] | K77 | K93 | x(2,N) | K109 | K125 | x(3,N) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(1,N) | + // Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) | + // Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) | + + // clang-format on + static constexpr uint32_t VW = vectorSize(BFragT{}); + static_assert(VW == BLOCK_X, 
"Fragment size must be equal to BLOCK_X"); + + // To start the loading process, let's visualize in 2D coords. + // Each thread will load 1 element + // We need to know where to start + auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row + threadIdx.x % BLOCK_N); // Col + + // Flatten to 1D col_major offsets. + auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; + + auto startOffset = col_major(startCoord2D, BLOCK_K / BLOCK_X); + + // obtain 8-bit exponent + fragX = utils::get_exponent_value(scale_ptr[startOffset]) & 0xFF; + + return load_B_col_major(input_ptr); } // Define a store function for C @@ -309,6 +617,129 @@ struct store_C_col_major } }; +// Define a store function for C +// Size: (BLOCK_M x BLOCK_N) +// ASSUMPTION: +// - We want contiguous BLOCK_N sized row neighbors in register. +// - Data is in row major format +template +struct store_C_row_major; + +// Here we want to store a 16x16 block of data. +// +// Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | +// N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | +// Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Vector +// Register Element ------------ ------------- ------------ -------------- Element +// Reg0 | M0 | M4 | M8 | M12 | v[0] +// Reg1 | M1 | M5 | M9 | M13 | v[1] +// Reg2 | M2 | M6 | M10 | M14 | v[2] +// Reg3 | M3 | M7 | M11 | M15 | v[3] +template +struct store_C_row_major +{ + __device__ void operator()(CType* output, CFragT cFrag) + { + static constexpr uint32_t VW = vectorSize(cFrag); // 4 + static constexpr uint32_t Dim = 16; + + // Each thread will load 4 elements. + // We need to know where they start, and where the next elements are. + auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row + threadIdx.x % Dim); // Col + auto stepCoord2D = std::make_pair(1u, 0u); + + // Flatten to 1D row_major offsets. 
+ auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + auto startOffset = row_major(startCoord2D, 16); + auto kOffset = row_major(stepCoord2D, 16); + + auto* fragPtr = reinterpret_cast(output + startOffset); + *fragPtr = cFrag; + + // If you notice carefully, kOffset != 1. + // This means the following vector is updated with 4 non-contiguous offsets, + // which the compiler will separate into 4 different global_store_dword instructions. + output[startOffset] = cFrag[0]; // v[0] = Reg 0 + output[startOffset + kOffset] = cFrag[1]; // v[1] = Reg 1 + output[startOffset + 2 * kOffset] = cFrag[2]; // v[2] = Reg 2 + output[startOffset + 3 * kOffset] = cFrag[3]; // v[3] = Reg 3 + } +}; + +// Here we want to store a 32x32 block of data. +// Register Mapping: + +// Size | BLOCK_N | BLOCK_N | +// N | 0 ... 31 | 0 ... 31 | +// Thread Id | 0 ... 31 | 32 ... 63 | Vector +// Register Element ------------ ------------- Element +// Reg0 | M0 | M4 | v[0] +// Reg1 | M1 | M5 | v[1] +// Reg2 | M2 | M6 | v[2] +// Reg3 | M3 | M7 | v[3] +// ____________ _____________ +// Reg4 | M8 | M12 | v[4] +// Reg5 | M9 | M13 | v[5] +// Reg6 | M10 | M14 | v[6] +// Reg7 | M11 | M15 | v[7] +// ____________ _____________ +// Reg8 | M16 | M20 | v[8] +// Reg9 | M17 | M21 | v[9] +// Reg10 | M18 | M22 | v[10] +// Reg11 | M19 | M23 | v[11] +// ____________ _____________ +// Reg12 | M24 | M28 | v[12] +// Reg13 | M25 | M29 | v[13] +// Reg14 | M26 | M30 | v[14] +// Reg15 | M27 | M31 | v[15] + +template +struct store_C_row_major +{ + __device__ void operator()(CType* output, CFragT cFrag) + { + static constexpr uint32_t WAVE_SIZE = 64; + static constexpr uint32_t VW = 4; // This VW is per 'chunk' + static constexpr uint32_t Dim = 32; // BLOCK_N + static constexpr uint32_t M_PER_VW_CHUNK = VW * WAVE_SIZE / 32; // 8 + + auto startCoord2D = std::make_pair((threadIdx.x / Dim) * VW, // Row + threadIdx.x % Dim); // Col + + // Minor step for each 'chunk' + auto 
minorStepCoord2D = std::make_pair(1u, 0u); + + // Major step between 'chunks' + auto majorStepCoord2D = std::make_pair(M_PER_VW_CHUNK, 0); + + // Flatten to 1D row_major offsets. + auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; + + auto startOffset = row_major(startCoord2D, 32); + auto kMinorOffset = row_major(minorStepCoord2D, 32); + auto kMajorOffset = row_major(majorStepCoord2D, 32); + + output[startOffset] = cFrag[0]; // v[0] = Reg 0 + output[startOffset + kMinorOffset] = cFrag[1]; // v[1] = Reg 1 + output[startOffset + 2 * kMinorOffset] = cFrag[2]; // v[2] = Reg 2 + output[startOffset + 3 * kMinorOffset] = cFrag[3]; // v[3] = Reg 3 + output[startOffset + kMajorOffset] = cFrag[4]; // v[4] = Reg 4 + output[startOffset + kMajorOffset + kMinorOffset] = cFrag[5]; // v[5] = Reg 5 + output[startOffset + kMajorOffset + 2 * kMinorOffset] = cFrag[6]; // v[6] = Reg 6 + output[startOffset + kMajorOffset + 3 * kMinorOffset] = cFrag[7]; // v[7] = Reg 7 + output[startOffset + 2 * kMajorOffset] = cFrag[8]; // v[8] = Reg 8 + output[startOffset + 2 * kMajorOffset + kMinorOffset] = cFrag[9]; // v[9] = Reg 9 + output[startOffset + 2 * kMajorOffset + 2 * kMinorOffset] = cFrag[10]; // v[10] = Reg 10 + output[startOffset + 2 * kMajorOffset + 3 * kMinorOffset] = cFrag[11]; // v[11] = Reg 11 + output[startOffset + 3 * kMajorOffset] = cFrag[12]; // v[12] = Reg 12 + output[startOffset + 3 * kMajorOffset + kMinorOffset] = cFrag[13]; // v[13] = Reg 13 + output[startOffset + 3 * kMajorOffset + 2 * kMinorOffset] = cFrag[14]; // v[14] = Reg 14 + output[startOffset + 3 * kMajorOffset + 3 * kMinorOffset] = cFrag[15]; // v[15] = Reg 15 + } +}; + template {}; storeC(c, fragC); } + +template +__global__ void +matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c) +{ + constexpr int WAVE_SIZE = 64; + assert(threadIdx.x < WAVE_SIZE); + assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); + + using AFragT = 
vector_type::type; + using BFragT = vector_type::type; + using CFragT = vector_type::type; + using AccumFragT = vector_type; + using RawAccumFragT = vector_type::type; + using ScaleFragT = int32_t; + + // Create frags + auto fragA = AFragT{}; + auto fragB = BFragT{}; + auto fragC = CFragT{}; + auto fragAcc = AccumFragT{0}; + auto fragXa = ScaleFragT{0}; + auto fragXb = ScaleFragT{0}; + + // Load the inputs. + // A = row major, BLOCK_M x BLOCK_K + fragA = load_mx_A_row_major( + a, xa, fragXa); + + // B = col major, BLOCK_K x BLOCK_N + fragB = load_mx_B_col_major( + b, xb, fragXb); + + // Scaled Matrix multiply-accumulate using MFMA units + // Accumulation intermediate = BLOCK_M x BLOCK_N + mfma_type_selector{}( + fragA, fragXa, fragB, fragXb, fragAcc); + + for(int i = 0; i < vectorSize(fragC); ++i) + { + fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); + } + + auto storeC = store_C_row_major{}; + storeC(c, fragC); +} + /** * @brief Structure to hold dimension parameters for GEMM tensors. 
* @@ -373,6 +859,225 @@ struct GemmParams ck::index_t StrideC = -1; }; +namespace mxmfma_test { +template +void RunHostGEMM(const Tensor& A, + const Tensor& a_scales, + const Tensor& B, + const Tensor& b_scales, + Tensor& C) +{ + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + A, a_scales, B, b_scales, C, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); +} + +template +bool RunDeviceGEMM(KernelType kernel, + const Tensor& A, + const Tensor& a_scales, + const Tensor& B, + const Tensor& b_scales, + Tensor& C) +{ + DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize()); + DeviceMem a_scales_device_buf(sizeof(ScaleType) * a_scales.mDesc.GetElementSpaceSize()); + DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize()); + DeviceMem b_scales_device_buf(sizeof(ScaleType) * b_scales.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize()); + + a_m_k_device_buf.ToDevice(A.mData.data()); + a_scales_device_buf.ToDevice(a_scales.mData.data()); + b_n_k_device_buf.ToDevice(B.mData.data()); + b_scales_device_buf.ToDevice(b_scales.mData.data()); + kernel<<<1, 64>>>(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(a_scales_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(b_scales_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer())); + c_m_n_device_buf.FromDevice(C.mData.data()); + + return true; +} + +template +struct TestMXMFMA +{ + auto PrepareGemmTensors(const GemmParams& params, index_t init) + { + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(std::is_same::value) + 
{ + return HostTensorDescriptor(std::vector({row, col}), + std::vector({stride, 1})); + } + else + { + return HostTensorDescriptor(std::vector({row, col}), + std::vector({1, stride})); + } + }; + + Tensor a_m_k( + f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{})); + Tensor a_scales( + f_host_tensor_descriptor(params.M, params.K / BLOCK_X, params.K / BLOCK_X, ALayout{})); + Tensor b_n_k( + f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{})); + Tensor b_scales( + f_host_tensor_descriptor(params.K / BLOCK_X, params.N, params.K / BLOCK_X, BLayout{})); + Tensor c_m_n_host_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + Tensor c_m_n_device_result( + f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); + + switch(init) + { + case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_scales.GenerateTensorValue( + GeneratorTensor_1{ScaleType{0.015625f}}); // 1/64 + // NOTE: not all numbers are representable in FP8, BF8, etc. 
+ // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32 + b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); + b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f}}); + break; + case 1: + // results in C = {K} + a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{512.0f}}); + b_n_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); + b_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{1.0f / 512}}); + break; + case 2: + // expect small round off errors + a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + a_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + + b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129}); + break; + + case 3: + // expect small round off errors + a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); + a_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + b_n_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); + b_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + break; + default: + // all initial values are representable in FP8, BF8 + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + a_scales.GenerateTensorValue( + GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] + b_n_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + b_scales.GenerateTensorValue( + GeneratorTensor_2{122, 129}); // scales: [1/32,..., 2] + + break; + } + + return std::make_tuple( + a_m_k, a_scales, b_n_k, b_scales, c_m_n_host_result, c_m_n_device_result); + } + + auto operator()(const DeviceMFMA& mfma_kernel, index_t init) + { + // Arrange + GemmParams params; + params.M = BLOCK_M; + params.N = BLOCK_N; + params.K = BLOCK_K; + + auto f_get_default_stride = [](std::size_t row, + std::size_t col, + ck::index_t stride, + auto layout) { + if(stride == -1) + { + // give a 
chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + params.StrideA = f_get_default_stride(BLOCK_M, BLOCK_K, params.StrideA, ALayout{}); + params.StrideB = f_get_default_stride(BLOCK_K, BLOCK_N, params.StrideB, BLayout{}); + params.StrideC = f_get_default_stride(BLOCK_M, BLOCK_N, params.StrideC, CLayout{}); + + auto host_tensors = PrepareGemmTensors(params, init); + + const Tensor& a = std::get<0>(host_tensors); + const Tensor& a_scales = std::get<1>(host_tensors); + const Tensor& b = std::get<2>(host_tensors); + const Tensor& b_scales = std::get<3>(host_tensors); + Tensor& c_host = std::get<4>(host_tensors); + Tensor& c_device = std::get<5>(host_tensors); + + RunHostGEMM(a, a_scales, b, b_scales, c_host); + + RunDeviceGEMM(mfma_kernel, a, a_scales, b, b_scales, c_device); + + bool res = false; + if constexpr(std::is_same::value || + std::is_same::value) + { + res = ck::utils::check_err(c_device.mData, c_host.mData); + } + else + { + std::cout << "UNSUPPORTED CDataType" << std::endl; + } + + return res; + } +}; + +} // namespace mxmfma_test + namespace mfma_test { template