From 789ed576626450b2eebf7ea12a7eca2a4c905794 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Tue, 6 May 2025 09:24:00 -0500 Subject: [PATCH] Add FP4 MX MFMA tests (#2151) * Add conversion tests * Fix ctor * Fix nan logic * Fix conversion logic * Permute packed f4_t values * Fix conversion to float, repack vector elements * Fix device tests * Permute elements in a vector * Add a repro test * Add a conversion for a repro test * Update test vectors * Update conversion * Fix the test * Update test vector generator * Fix vector sr conversion * Permute conversion args * Update conversion * Test * Fix packing * Simplify conversion function * Pack conversion in a loop * Pack conversion in a loop * Pack another conversion in a loop * Pack one more conversion in a loop * Pack the last conversion in a loop * Clean up * Add ops * Add tests * Add missing utils * Update reference mx gemm * Add f4x2 init mode * Update host tensor utils * Update chunk size for f4x2 * Add non scaled ops * Add a type utility * Update non scaled reference kernel * Add non scaled tests * Debug mfma arguments * Add more debug info * Update chunk size * Update data layout * Add more debugging * Fix B stride * Fix reference gemm * Fix build * One more reference fix * Add more debug info * Disable some tests * Enable tests * Add fp4 dimensions * Update reference kernels * Temp edits * Remove leftovers * Fix conflicts * Clean up * More clean up * Revert "More clean up" This reverts commit d8d35a0846a8c2f0ccc7defe5f4fc7cc4ef36760. 
* Add layouts to tests --------- Co-authored-by: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> [ROCm/composable_kernel commit: 8a0d659f92897e1ae99e4dc0ea4842a2c78170ab] --- include/ck/library/utility/host_tensor.hpp | 21 +- .../library/utility/host_tensor_generator.hpp | 48 ++- include/ck/utility/amd_xdlops.hpp | 122 +++++++ include/ck/utility/data_type.hpp | 7 + .../cpu/reference_gemm.hpp | 20 ++ .../cpu/reference_mx_gemm.hpp | 50 ++- test/mx_mfma_op/mx_mfma_op.cpp | 114 ++++++- test/mx_mfma_op/mx_mfma_op.hpp | 307 +++++++++++++++--- 8 files changed, 610 insertions(+), 79 deletions(-) diff --git a/include/ck/library/utility/host_tensor.hpp b/include/ck/library/utility/host_tensor.hpp index 2cbca29afc..71417ce7bf 100644 --- a/include/ck/library/utility/host_tensor.hpp +++ b/include/ck/library/utility/host_tensor.hpp @@ -51,7 +51,8 @@ std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim) { os << ck::type_convert(v); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v || + std::is_same_v) { const auto packed_floats = ck::type_convert(v); const ck::vector_type vector_of_floats{packed_floats}; @@ -359,7 +360,8 @@ struct Tensor std::size_t GetElementSpaceSize() const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return (mDesc.GetElementSpaceSize() + 1) / 2; } @@ -514,7 +516,8 @@ struct Tensor template std::size_t GetOffsetFromMultiIndex(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mDesc.GetOffsetFromMultiIndex(is...) / 2; } @@ -527,7 +530,8 @@ struct Tensor template T& operator()(Is... is) { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) 
/ 2]; } @@ -540,7 +544,8 @@ struct Tensor template const T& operator()(Is... is) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2]; } @@ -552,7 +557,8 @@ struct Tensor T& operator()(std::vector idx) { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } @@ -564,7 +570,8 @@ struct Tensor const T& operator()(std::vector idx) const { - if constexpr(ck::is_same_v, ck::pk_i4_t>) + if constexpr(ck::is_same_v, ck::pk_i4_t> || + ck::is_same_v, ck::f4x2_pk_t>) { return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2]; } diff --git a/include/ck/library/utility/host_tensor_generator.hpp b/include/ck/library/utility/host_tensor_generator.hpp index 274051da83..785f74a3c0 100644 --- a/include/ck/library/utility/host_tensor_generator.hpp +++ b/include/ck/library/utility/host_tensor_generator.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -81,6 +81,18 @@ struct GeneratorTensor_1 } }; +template <> +struct GeneratorTensor_1 +{ + float value = 1.0; + + template + ck::f4x2_pk_t operator()(Is...) + { + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{value, value})}; + } +}; + template <> struct GeneratorTensor_1 { @@ -209,6 +221,21 @@ struct GeneratorTensor_2 } }; +template <> +struct GeneratorTensor_2 +{ + int min_value = 0; + int max_value = 1; + + template + ck::f4x2_pk_t operator()(Is...) 
+ { + float tmp0 = (std::rand() % (max_value - min_value)) + min_value; + float tmp1 = (std::rand() % (max_value - min_value)) + min_value; + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{tmp0, tmp1})}; + } +}; + template struct GeneratorTensor_3 { @@ -296,6 +323,25 @@ struct GeneratorTensor_3 } }; +template <> +struct GeneratorTensor_3 +{ + float min_value = 0; + float max_value = 1; + + template + ck::f4x2_pk_t operator()(Is...) + { + float tmp0 = float(std::rand()) / float(RAND_MAX); + float tmp1 = float(std::rand()) / float(RAND_MAX); + + float fp32_tmp0 = min_value + tmp0 * (max_value - min_value); + float fp32_tmp1 = min_value + tmp1 * (max_value - min_value); + + return ck::f4x2_pk_t{ck::type_convert(ck::float2_t{fp32_tmp0, fp32_tmp1})}; + } +}; + template struct GeneratorTensor_4 { diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 71e1937a23..66c4958e1d 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -508,6 +508,34 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; @@ -589,6 +617,40 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32> ignore = reg_b; ignore = scale_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + 
const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; #endif } }; @@ -686,6 +748,39 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16> #endif } + template + __device__ static void Run(const f4x32_t& reg_a, + const int32_t scale_a, + const f4x32_t& reg_b, + const int32_t scale_b, + FloatC& reg_c) + { +#if defined(__gfx950__) + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + scale_a, + 0, // OPSEL + scale_b); +#else + ignore = reg_a; + ignore = scale_a; + ignore = reg_b; + ignore = scale_b; + ignore = reg_c; +#endif + } + template __device__ static void Run(const bf8x32_t& reg_a, const int32_t& scale_a, @@ -748,6 +843,33 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16> ignore = reg_a; ignore = reg_b; ignore = reg_c; +#endif + } + + template + __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx950__) + int32x4_t arg_a = bit_cast(reg_a); + int32x4_t arg_b = bit_cast(reg_b); + + using arg_type = int32x8_t; + + reg_c.template AsType()(Number<0>{}) = + 
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4( + arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0}, + arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0}, + reg_c.template AsType()[Number<0>{}], + 4, // cbsz + 4, // blgp + 0, // OPSEL + 0, + 0, // OPSEL + 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; #endif } }; diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 79bd717501..a6106bb146 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -470,6 +470,13 @@ struct scalar_type static constexpr index_t vector_size = 1; }; +template <> +struct scalar_type +{ + using type = f4x2_pk_t::type; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index 7e2482807d..c8d284a1d7 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -79,6 +79,16 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_a = type_convert(i4); } + else if constexpr(is_same_v) + { + // TODO: add support for ColMajor layout as well + if(k % 2 == 1) + v_a = type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))); + else + v_a = type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); + } else { arg.a_element_op_(v_a, arg.a_m_k_(m, k)); @@ -95,6 +105,16 @@ struct ReferenceGemm : public device::BaseOperator i4 = i4 - 8; v_b = type_convert(i4); } + else if constexpr(is_same_v) + { + // TODO: add support for RowMajor layout as well + if(k % 2 == 1) + v_b = type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))); + else + v_b = type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); + } else { 
arg.b_element_op_(v_b, arg.b_k_n_(k, n)); diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp index 649f130c41..e8fdcf1acd 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp @@ -89,9 +89,28 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - a_m_k_scaled(m, k) = - type_convert(arg.a_m_k_(m, k)) * - type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + if constexpr(is_same_v) + { + // TODO: add support for ColMajor layout as well + if(k % 2 == 1) + a_m_k_scaled(m, k) = + type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + else + a_m_k_scaled(m, k) = + type_convert( + f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))) * + type_convert( + arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } + else + { + a_m_k_scaled(m, k) = + type_convert(arg.a_m_k_(m, k)) * + type_convert(arg.a_m_kblock_scales_(m, k / SCALE_BLOCK)); + } } } @@ -99,9 +118,28 @@ struct ReferenceMXGemm : public device::BaseOperator { for(size_t k = 0; k < K; k++) { - b_k_n_scaled(k, n) = - type_convert(arg.b_k_n_(k, n)) * - type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + if constexpr(is_same_v) + { + // TODO: add support for RowMajor layout as well + if(k % 2 == 1) + b_k_n_scaled(k, n) = + type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + else + b_k_n_scaled(k, n) = + type_convert( + f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))) * + type_convert( + arg.b_kblock_n_scales_(k / SCALE_BLOCK, n)); + } + else + { + b_k_n_scaled(k, n) = + type_convert(arg.b_k_n_(k, n)) * + type_convert(arg.b_kblock_n_scales_(k / SCALE_BLOCK, 
n)); + } } } diff --git a/test/mx_mfma_op/mx_mfma_op.cpp b/test/mx_mfma_op/mx_mfma_op.cpp index f65e89bb82..fddb8288a6 100644 --- a/test/mx_mfma_op/mx_mfma_op.cpp +++ b/test/mx_mfma_op/mx_mfma_op.cpp @@ -6,6 +6,8 @@ #include "mx_mfma_op.hpp" using ck::e8m0_bexp_t; +using ck::f4_t; +using ck::f4x2_pk_t; using ck::f8_t; using ck::half_t; using ck::type_convert; @@ -16,7 +18,7 @@ using ck::type_convert; * @param init - selects initialization algorithm for A and B tensors */ template -bool run_mfma_test(ck::index_t init) +bool run_mfma_km_kn_nm_test(ck::index_t init) { using ALayout = ck::tensor_layout::gemm::ColumnMajor; using BLayout = ck::tensor_layout::gemm::ColumnMajor; @@ -30,7 +32,8 @@ bool run_mfma_test(ck::index_t init) constexpr auto BLOCK_N = mfma_instr.n_per_blk; constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; - const auto mfma_kernel = ck::matmul; + const auto mfma_kernel = ck:: + matmul; bool pass = true; @@ -52,15 +55,72 @@ bool run_mfma_test(ck::index_t init) TEST(MFMA, FP8MFMA16x16x128) { - auto AB_init = 4; - auto pass = run_mfma_test(AB_init); + auto AB_init = 5; + auto pass = run_mfma_km_kn_nm_test(AB_init); EXPECT_TRUE(pass); } TEST(MFMA, FP8MFMA32x32x64) +{ + auto AB_init = 5; + auto pass = run_mfma_km_kn_nm_test(AB_init); + EXPECT_TRUE(pass); +} + +/** + * @brief Run the test for the given MFMA instruction + * + * @param init - selects initialization algorithm for A and B tensors + */ +template +bool run_mfma_mk_kn_mn_test(ck::index_t init) +{ + using ALayout = ck::tensor_layout::gemm::RowMajor; + using BLayout = ck::tensor_layout::gemm::ColumnMajor; + using CLayout = ck::tensor_layout::gemm::RowMajor; + + using AccType = float; // only MFMA_F32 instructions supported + using CPUAccType = AccType; + + ck::mfma_type(mfma)> mfma_instr; + constexpr auto BLOCK_M = mfma_instr.m_per_blk; + constexpr auto BLOCK_N = mfma_instr.n_per_blk; + constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; + + const auto 
mfma_kernel = ck:: + matmul; + + bool pass = true; + + pass = ck::mfma_test::TestMFMA{}(mfma_kernel, init); + + return pass; +} + +TEST(MFMA, FP4MFMA16x16x128) { auto AB_init = 4; - auto pass = run_mfma_test(AB_init); + auto pass = run_mfma_mk_kn_mn_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MFMA, FP4MFMA32x32x64) +{ + auto AB_init = 4; + auto pass = run_mfma_mk_kn_mn_test( + AB_init); EXPECT_TRUE(pass); } @@ -70,7 +130,7 @@ TEST(MFMA, FP8MFMA32x32x64) * @param init - selects initialization algorithm for A and B tensors */ template -bool run_mxmfma_test(ck::index_t init) +bool run_mxmfma_mk_kn_mn_test(ck::index_t init) { static_assert(mfma == ck::MFMA_F8F6F4::SCALE_F32_16x16x128 || mfma == ck::MFMA_F8F6F4::SCALE_F32_32x32x64, @@ -88,8 +148,18 @@ bool run_mxmfma_test(ck::index_t init) constexpr auto BLOCK_K = mfma_instr.num_input_blks * mfma_instr.k_per_blk; constexpr auto BLOCK_X = 32; // scaling vector size - const auto mx_mfma_kernel = - ck::matmul; + const auto mx_mfma_kernel = ck::matmul; bool pass = true; @@ -111,14 +181,34 @@ bool run_mxmfma_test(ck::index_t init) TEST(MXMFMA, MXFP8MFMA16x16x128) { - auto AB_init = 7; - auto pass = run_mxmfma_test(AB_init); + auto AB_init = 5; + auto pass = + run_mxmfma_mk_kn_mn_test(AB_init); EXPECT_TRUE(pass); } TEST(MXMFMA, MXFP8MFMA32x32x64) { - auto AB_init = 7; - auto pass = run_mxmfma_test(AB_init); + auto AB_init = 5; + auto pass = + run_mxmfma_mk_kn_mn_test(AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP4MFMA16x16x128) +{ + auto AB_init = 4; + auto pass = + run_mxmfma_mk_kn_mn_test( + AB_init); + EXPECT_TRUE(pass); +} + +TEST(MXMFMA, MXFP4MFMA32x32x64) +{ + auto AB_init = 4; + auto pass = + run_mxmfma_mk_kn_mn_test( + AB_init); EXPECT_TRUE(pass); } diff --git a/test/mx_mfma_op/mx_mfma_op.hpp b/test/mx_mfma_op/mx_mfma_op.hpp index d22157c3b3..9ce871cfb1 100644 --- a/test/mx_mfma_op/mx_mfma_op.hpp +++ b/test/mx_mfma_op/mx_mfma_op.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" +#include 
"ck/utility/data_type.hpp" #include "ck/library/utility/host_tensor.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -111,7 +112,7 @@ template __device__ AFragT load_A_col_major(AType const* input_ptr) { // clang-format off - // Register Mapping for 16x128: || Register Mapping for 32x64: + // Register Mapping for 16x128 for FP8: || Register Mapping for 32x64 for FP8: // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 31 | Vector | // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| @@ -176,13 +177,19 @@ __device__ AFragT load_A_col_major(AType const* input_ptr) auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_M); auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_M); - using ARawT = typename scalar_type::type; - using AScalarFragT = vector_type::type; + using ARawT = typename scalar_type::type; + using AScalarFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; AScalarFragT fragA{}; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + #pragma unroll - for(int chunk = 0; chunk < 2; chunk++) + for(int chunk = 0; chunk < num_chunks; chunk++) { #pragma unroll for(uint32_t i = 0; i < chunk_size; i++) @@ -241,6 +248,28 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + + // Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4: + // Size | BLOCK_M | BLOCK_M | BLOCK_M | BLOCK_M | || Size | BLOCK_M | BLOCK_M | | + // M | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || M | 0 ... 31 | 0 ... 
31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | // clang-format on static 
constexpr int32_t WAVE_SIZE = 64; @@ -265,23 +294,34 @@ __device__ AFragT load_A_row_major(AType const* input_ptr) auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; // BLOCK_K is a stride in A matrix - auto startOffset = row_major(startCoord2D, BLOCK_K); - // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K); - auto kMajorOffset = row_major(majorStepCoord2D, BLOCK_K); + auto startOffset = row_major( + startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + // auto kMinorOffset = row_major(minorStepCoord2D, BLOCK_K / + // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + auto kMajorOffset = + row_major(majorStepCoord2D, + BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); using ARawT = typename scalar_type::type; using AScalarFragT = vector_type::type; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + union { AFragT frag; - AScalarFragT chunks[2]; + AScalarFragT chunks[num_chunks]; } fragA{}; - auto* fragPtr = reinterpret_cast(input_ptr + startOffset); - fragA.chunks[0] = *fragPtr; - fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); - fragA.chunks[1] = *fragPtr; + const AScalarFragT* fragPtr; + + for(index_t chunk_idx = 0; chunk_idx < num_chunks; chunk_idx++) + { + fragPtr = reinterpret_cast(input_ptr + startOffset + + chunk_idx * kMajorOffset); + fragA.chunks[chunk_idx] = *fragPtr; + } return fragA.frag; } @@ -339,15 +379,35 @@ __device__ AFragT load_mx_A_row_major(AType const* input_ptr, // Reg 7 [8:15] | K77 | K93 | x(M,2) | K109 | K125 | x(M,3) | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | x(M,1) | // Reg 7 [16:23] | K78 | K94 | x(M,2) | K110 | K126 | x(M,3) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(M,1) | // Reg 7 [24:31] | K79 | K95 | x(M,2) | K111 | K127 | x(M,3) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(M,1) | + + // Register Mapping for 16x128 for FP4: || Register Mapping for 32x64 for FP4: + // Size | BLOCK_M | | BLOCK_M | | BLOCK_M | | 
BLOCK_M | | || Size | BLOCK_M | | BLOCK_M | | | + // M | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || M | 0 ... 31 | | 0 ... 31 | | Vector | + // Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element| + // Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------| + // Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | K64K65 | x(M,2) | K96K97 | x(M,3) | v[0] || Reg 0 [0:7] | K0K1 | x(M,0) | K32K33 | x(M,1) | v[0] | + // Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | K66K67 | x(M,2) | K98K99 | x(M,3) | v[1] || Reg 0 [8:15] | K2K3 | x(M,0) | K34K35 | x(M,1) | v[1] | + // Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | K68K69 | x(M,2) | K100K101 | x(M,3) | v[2] || Reg 0 [16:23] | K4K5 | x(M,0) | K36K37 | x(M,1) | v[2] | + // Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | K70K71 | x(M,2) | K102K103 | x(M,3) | v[3] || Reg 0 [24:31] | K6K7 | x(M,0) | K38K39 | x(M,1) | v[3] | + // Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | K72K73 | x(M,2) | K104K105 | x(M,3) | v[4] || Reg 1 [0:7] | K8K9 | x(M,0) | K40K41 | x(M,1) | v[4] | + // Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | K74K75 | x(M,2) | K106K107 | x(M,3) | v[5] || Reg 1 [8:15] | K10K11 | x(M,0) | K42K43 | x(M,1) | v[5] | + // Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | K76K77 | x(M,2) | K108K109 | x(M,3) | v[6] || Reg 1 [16:23] | K12K13 | x(M,0) | K44K45 | x(M,1) | v[6] | + // Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | K78K79 | x(M,2) | K110K111 | x(M,3) | v[7] || Reg 1 [24:31] | K14K15 | x(M,0) | K46K47 | x(M,1) | v[7] | + // Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | K80K81 | x(M,2) | K112K113 | x(M,3) | v[8] || Reg 2 [0:7] | K16K17 | x(M,0) | K48K49 | x(M,1) | v[8] | + // Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 
| x(M,1) | K82K83 | x(M,2) | K114K115 | x(M,3) | v[9] || Reg 2 [8:15] | K18K19 | x(M,0) | K50K51 | x(M,1) | v[9] | + // Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | K84K85 | x(M,2) | K116K117 | x(M,3) | v[10] || Reg 2 [16:23] | K20K21 | x(M,0) | K52K53 | x(M,1) | v[10] | + // Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | K86K87 | x(M,2) | K118K119 | x(M,3) | v[11] || Reg 2 [24:31] | K22K23 | x(M,0) | K54K55 | x(M,1) | v[11] | + // Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | K88K89 | x(M,2) | K120K121 | x(M,3) | v[12] || Reg 3 [0:7] | K24K25 | x(M,0) | K56K57 | x(M,1) | v[12] | + // Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | K90K91 | x(M,2) | K122K123 | x(M,3) | v[13] || Reg 3 [8:15] | K26K27 | x(M,0) | K58K59 | x(M,1) | v[13] | + // Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | K92K93 | x(M,2) | K124K125 | x(M,3) | v[14] || Reg 3 [16:23] | K28K29 | x(M,0) | K60K61 | x(M,1) | v[14] | + // Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | K94K95 | x(M,2) | K126K127 | x(M,3) | v[15] || Reg 3 [24:31] | K30K31 | x(M,0) | K62K63 | x(M,1) | v[15] | // clang-format on - static constexpr uint32_t VW = vectorSize(AFragT{}); - static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); // To start the loading process, let's visualize in 2D coords. // Each thread will load 1 element // We need to know where they start - auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row - (threadIdx.x / BLOCK_M) * VW / BLOCK_X); // Col + auto startCoord2D = std::make_pair(threadIdx.x % BLOCK_M, // Row + (threadIdx.x / BLOCK_M)); // Col // Flatten to 1D row_major offsets. 
auto row_major = [](auto const& coord, auto ld) { return coord.first * ld + coord.second; }; @@ -369,7 +429,7 @@ template __device__ BFragT load_B_col_major(BType const* input_ptr) { // clang-format off - // Register Mapping for 128x16: || Register Mapping for 64x32: + // Register Mapping for 128x16 for FP8: || Register Mapping for 64x32 for FP8: // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 63 | Element| @@ -406,6 +466,28 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) // Reg 7 [8:15] | K77 | K93 | K109 | K125 | v[29] || Reg 7 [8:15] | K45 | K61 | v[29] | // Reg 7 [16:23] | K78 | K94 | K110 | K126 | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | // Reg 7 [24:31] | K79 | K95 | K111 | K127 | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | + + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | // clang-format on static constexpr int32_t WAVE_SIZE = 64; @@ -430,23 +512,34 @@ __device__ BFragT load_B_col_major(BType const* input_ptr) auto 
majorStepCoord2D = std::make_pair(chunk_offset, 0); // read a chunk from a col // BLOCK_K is a stride in B matrix - auto startOffset = col_major(startCoord2D, BLOCK_K); - // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K); - auto kMajorOffset = col_major(majorStepCoord2D, BLOCK_K); + auto startOffset = col_major( + startCoord2D, BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + // auto kMinorOffset = col_major(minorStepCoord2D, BLOCK_K / + // (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); + auto kMajorOffset = + col_major(majorStepCoord2D, + BLOCK_K / (ck::is_same_v, ck::f4x2_pk_t> ? 2 : 1)); using BRawT = typename scalar_type::type; using BScalarFragT = vector_type::type; + constexpr index_t num_chunks = + (ck::is_same_v, ck::f4x2_pk_t> ? 1 : 2); + union { BFragT frag; - BScalarFragT chunks[2]; + BScalarFragT chunks[num_chunks]; } fragB{}; - auto* fragPtr = reinterpret_cast(input_ptr + startOffset); - fragB.chunks[0] = *fragPtr; - fragPtr = reinterpret_cast(input_ptr + startOffset + kMajorOffset); - fragB.chunks[1] = *fragPtr; + const BScalarFragT* fragPtr; + + for(index_t chunk = 0; chunk < num_chunks; chunk++) + { + fragPtr = + reinterpret_cast(input_ptr + startOffset + chunk * kMajorOffset); + fragB.chunks[chunk] = *fragPtr; + } return fragB.frag; } @@ -506,15 +599,56 @@ __device__ BFragT load_mx_B_col_major(BType const* input_ptr, // Reg 7 [16:23] | K78 | K94 | x(2,N) | K110 | K126 | x(3,N) | v[30] || Reg 7 [16:23] | K46 | K62 | v[30] | x(1,N) | // Reg 7 [24:31] | K79 | K95 | x(2,N) | K111 | K127 | x(3,N) | v[31] || Reg 7 [24:31] | K47 | K63 | v[31] | x(1,N) | + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | BLOCK_N | BLOCK_N | BLOCK_N | || Size | BLOCK_N | BLOCK_N | | + // N | 0 ... 15 | 0 ... 15 | 0 ... 15 | 0 ... 15 | Vector || N | 0 ... 31 | 0 ... 31 | Vector | + // Thread Id | 0 ... 15 | 16 ... 31 | 32 ... 47 | 48 ... 63 | Element || Thread Id | 0 ... 31 | 32 ... 
63 | Element| + // Register Element |------------|-------------|------------|-------------|-----------|| Register Element |------------|-------------|--------| + // Reg 0 [0:7] | K0K1 | K32K33 | K64K65 | K96K97 | v[0] || Reg 0 [0:7] | K0K1 | K32K33 | v[0] | + // Reg 0 [8:15] | K2K3 | K34K35 | K66K67 | K98K99 | v[1] || Reg 0 [8:15] | K2K3 | K34K35 | v[1] | + // Reg 0 [16:23] | K4K5 | K36K37 | K68K69 | K100K101 | v[2] || Reg 0 [16:23] | K4K5 | K36K37 | v[2] | + // Reg 0 [24:31] | K6K7 | K38K39 | K70K71 | K102K103 | v[3] || Reg 0 [24:31] | K6K7 | K38K39 | v[3] | + // Reg 1 [0:7] | K8K9 | K40K41 | K72K73 | K104K105 | v[4] || Reg 1 [0:7] | K8K9 | K40K41 | v[4] | + // Reg 1 [8:15] | K10K11 | K42K43 | K74K75 | K106K107 | v[5] || Reg 1 [8:15] | K10K11 | K42K43 | v[5] | + // Reg 1 [16:23] | K12K13 | K44K45 | K76K77 | K108K109 | v[6] || Reg 1 [16:23] | K12K13 | K44K45 | v[6] | + // Reg 1 [24:31] | K14K15 | K46K47 | K78K79 | K110K111 | v[7] || Reg 1 [24:31] | K14K15 | K46K47 | v[7] | + // Reg 2 [0:7] | K16K17 | K48K49 | K80K81 | K112K113 | v[8] || Reg 2 [0:7] | K16K17 | K48K49 | v[8] | + // Reg 2 [8:15] | K18K19 | K50K51 | K82K83 | K114K115 | v[9] || Reg 2 [8:15] | K18K19 | K50K51 | v[9] | + // Reg 2 [16:23] | K20K21 | K52K53 | K84K85 | K116K117 | v[10] || Reg 2 [16:23] | K20K21 | K52K53 | v[10] | + // Reg 2 [24:31] | K22K23 | K54K55 | K86K87 | K118K119 | v[11] || Reg 2 [24:31] | K22K23 | K54K55 | v[11] | + // Reg 3 [0:7] | K24K25 | K56K57 | K88K89 | K120K121 | v[12] || Reg 3 [0:7] | K24K25 | K56K57 | v[12] | + // Reg 3 [8:15] | K26K27 | K58K59 | K90K91 | K122K123 | v[13] || Reg 3 [8:15] | K26K27 | K58K59 | v[13] | + // Reg 3 [16:23] | K28K29 | K60K61 | K92K93 | K124K125 | v[14] || Reg 3 [16:23] | K28K29 | K60K61 | v[14] | + // Reg 3 [24:31] | K30K31 | K62K63 | K94K95 | K126K127 | v[15] || Reg 3 [24:31] | K30K31 | K62K63 | v[15] | + + // Register Mapping for 128x16 for FP4: || Register Mapping for 64x32 for FP4: + // Size | BLOCK_N | | BLOCK_N | | BLOCK_N | | BLOCK_N | | || 
Size | BLOCK_N | | BLOCK_N | | BLOCK_N | | BLOCK_N | | || Size | BLOCK_N | | BLOCK_N | | | + // N | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | 0 ... 15 | | Vector || N | 0 ... 31 | | 0 ... 31 | | Vector | + // Thread Id | 0 ... 15 | Scale | 16 ... 31 | Scale | 32 ... 47 | Scale | 48 ... 63 | Scale | Element || Thread Id | 0 ... 31 | Scale | 32 ... 63 | Scale | Element| + // Register Element |------------ ----------|------------- ----------|------------ ----------|------------- ----------|-----------|| Register Element |------------|----------|-------------|----------|--------| + // Reg 0 [0:7] | K0K1 | x(0,N) | K32K33 | x(1,N) | K64K65 | x(2,N) | K96K97 | x(3,N) | v[0] || Reg 0 [0:7] | K0K1 | x(0,N) | K32K33 | x(1,N) | v[0] | + // Reg 0 [8:15] | K2K3 | x(0,N) | K34K35 | x(1,N) | K66K67 | x(2,N) | K98K99 | x(3,N) | v[1] || Reg 0 [8:15] | K2K3 | x(0,N) | K34K35 | x(1,N) | v[1] | + // Reg 0 [16:23] | K4K5 | x(0,N) | K36K37 | x(1,N) | K68K69 | x(2,N) | K100K101 | x(3,N) | v[2] || Reg 0 [16:23] | K4K5 | x(0,N) | K36K37 | x(1,N) | v[2] | + // Reg 0 [24:31] | K6K7 | x(0,N) | K38K39 | x(1,N) | K70K71 | x(2,N) | K102K103 | x(3,N) | v[3] || Reg 0 [24:31] | K6K7 | x(0,N) | K38K39 | x(1,N) | v[3] | + // Reg 1 [0:7] | K8K9 | x(0,N) | K40K41 | x(1,N) | K72K73 | x(2,N) | K104K105 | x(3,N) | v[4] || Reg 1 [0:7] | K8K9 | x(0,N) | K40K41 | x(1,N) | v[4] | + // Reg 1 [8:15] | K10K11 | x(0,N) | K42K43 | x(1,N) | K74K75 | x(2,N) | K106K107 | x(3,N) | v[5] || Reg 1 [8:15] | K10K11 | x(0,N) | K42K43 | x(1,N) | v[5] | + // Reg 1 [16:23] | K12K13 | x(0,N) | K44K45 | x(1,N) | K76K77 | x(2,N) | K108K109 | x(3,N) | v[6] || Reg 1 [16:23] | K12K13 | x(0,N) | K44K45 | x(1,N) | v[6] | + // Reg 1 [24:31] | K14K15 | x(0,N) | K46K47 | x(1,N) | K78K79 | x(2,N) | K110K111 | x(3,N) | v[7] || Reg 1 [24:31] | K14K15 | x(0,N) | K46K47 | x(1,N) | v[7] | + // Reg 2 [0:7] | K16K17 | x(0,N) | K48K49 | x(1,N) | K80K81 | x(2,N) | K112K113 | x(3,N) | v[8] || Reg 2 [0:7] | K16K17 | x(0,N) | K48K49 | x(1,N) | v[8] | + // Reg 2 [8:15] | K18K19 | x(0,N) | K50K51 | x(1,N) | 
K82K83 | x(2,N) | K114K115 | x(3,N) | v[9] || Reg 2 [8:15] | K18K19 | x(0,N) | K50K51 | x(1,N) | v[9] | + // Reg 2 [16:23] | K20K21 | x(0,N) | K52K53 | x(1,N) | K84K85 | x(2,N) | K116K117 | x(3,N) | v[10] || Reg 2 [16:23] | K20K21 | x(0,N) | K52K53 | x(1,N) | v[10] | + // Reg 2 [24:31] | K22K23 | x(0,N) | K54K55 | x(1,N) | K86K87 | x(2,N) | K118K119 | x(3,N) | v[11] || Reg 2 [24:31] | K22K23 | x(0,N) | K54K55 | x(1,N) | v[11] | + // Reg 3 [0:7] | K24K25 | x(0,N) | K56K57 | x(1,N) | K88K89 | x(2,N) | K120K121 | x(3,N) | v[12] || Reg 3 [0:7] | K24K25 | x(0,N) | K56K57 | x(1,N) | v[12] | + // Reg 3 [8:15] | K26K27 | x(0,N) | K58K59 | x(1,N) | K90K91 | x(2,N) | K122K123 | x(3,N) | v[13] || Reg 3 [8:15] | K26K27 | x(0,N) | K58K59 | x(1,N) | v[13] | + // Reg 3 [16:23] | K28K29 | x(0,N) | K60K61 | x(1,N) | K92K93 | x(2,N) | K124K125 | x(3,N) | v[14] || Reg 3 [16:23] | K28K29 | x(0,N) | K60K61 | x(1,N) | v[14] | + // Reg 3 [24:31] | K30K31 | x(0,N) | K62K63 | x(1,N) | K94K95 | x(2,N) | K126K127 | x(3,N) | v[15] || Reg 3 [24:31] | K30K31 | x(0,N) | K62K63 | x(1,N) | v[15] | // clang-format on - static constexpr uint32_t VW = vectorSize(BFragT{}); - static_assert(VW == BLOCK_X, "Fragment size must be equal to BLOCK_X"); // To start the loading process, let's visualize in 2D coords. // Each thread will load 1 element // We need to know where to start - auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N) * VW / BLOCK_X, // Row - threadIdx.x % BLOCK_N); // Col + auto startCoord2D = std::make_pair((threadIdx.x / BLOCK_N), // Row + threadIdx.x % BLOCK_N); // Col // Flatten to 1D col_major offsets. 
auto col_major = [](auto const& coord, auto ld) { return coord.first + coord.second * ld; }; @@ -766,15 +900,24 @@ template + int32_t BLOCK_K, + typename ALayout, + typename BLayout, + typename CLayout> __global__ void matmul(const AType* a, const BType* b, CType* c) { constexpr int WAVE_SIZE = 64; assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 && blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using BFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -786,10 +929,23 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) auto fragAcc = AccumFragT{0}; // Load the inputs. - // A = col major, BLOCK_M x BLOCK_K - fragA = load_A_col_major(a); - // B = col major, BLOCK_K x BLOCK_N - fragB = load_B_col_major(b); + if constexpr(is_same_v) + { + fragA = load_A_row_major(a); + } + else + { + fragA = load_A_col_major(a); + } + + if constexpr(is_same_v) + { + printf("This layout is not implemented\n"); + } + else + { + fragB = load_B_col_major(b); + } // Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N @@ -801,8 +957,14 @@ __global__ void matmul(const AType* a, const BType* b, CType* c) fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); } - auto storeC = store_C_col_major{}; - storeC(c, fragC); + if constexpr(is_same_v) + { + store_C_row_major{}(c, fragC); + } + else + { + store_C_col_major{}(c, fragC); + } } template + int32_t BLOCK_X, + typename ALayout, + typename BLayout, + typename CLayout> __global__ void matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, CType* c) { @@ -821,8 +986,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, assert(threadIdx.x < WAVE_SIZE); assert(blockDim.x == 1 
&& blockDim.y == 1 && blockDim.z == 1); - using AFragT = vector_type::type; - using BFragT = vector_type::type; + using AFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; + using BFragT = + vector_type, ck::f4x2_pk_t> ? 2 : 1)>::type; using CFragT = vector_type::type; using AccumFragT = vector_type; using RawAccumFragT = vector_type::type; @@ -838,13 +1009,27 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, auto fragXb = BScaleFragT{}; // Load the inputs. - // A = col major, BLOCK_M x BLOCK_K - fragA = load_mx_A_row_major( - a, xa, fragXa); + if constexpr(is_same_v) + { + fragA = + load_mx_A_row_major( + a, xa, fragXa); + } + else + { + printf("This layout is not implemented\n"); + } - // B = col major, BLOCK_K x BLOCK_N - fragB = load_mx_B_col_major( - b, xb, fragXb); + if constexpr(is_same_v) + { + printf("This layout is not implemented\n"); + } + else + { + fragB = + load_mx_B_col_major( + b, xb, fragXb); + } // Scaled Matrix multiply-accumulate using MFMA units // Accumulation intermediate = BLOCK_M x BLOCK_N @@ -860,8 +1045,14 @@ matmul(const AType* a, const ScaleType* xa, const BType* b, const ScaleType* xb, fragC[i] = type_convert(fragAcc.template AsType()[Number<0>{}][i]); } - auto storeC = store_C_row_major{}; - storeC(c, fragC); + if constexpr(is_same_v) + { + store_C_row_major{}(c, fragC); + } + else + { + store_C_col_major{}(c, fragC); + } } /** @@ -993,8 +1184,7 @@ struct TestMXMFMA { case 0: a_m_k.GenerateTensorValue(GeneratorTensor_1{1.0f}); - a_scales.GenerateTensorValue( - GeneratorTensor_1{ScaleType{0.015625f}}); // 1/64 + a_scales.GenerateTensorValue(GeneratorTensor_1{ScaleType{0.015625f}}); // 1/64 // NOTE: not all numbers are representable in FP8, BF8, etc. 
// 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 16 18 20 20 20 22 24 24 24 26 28 28 28 30 32 b_n_k.GenerateTensorValue(GeneratorTensor_Sequential{}); @@ -1012,11 +1202,9 @@ struct TestMXMFMA a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} - b_n_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); b_scales.GenerateTensorValue(GeneratorTensor_2{126, 129}); break; - case 3: // expect small round off errors a_m_k.GenerateTensorValue(GeneratorTensor_4(0, 1)); @@ -1026,6 +1214,14 @@ struct TestMXMFMA b_scales.GenerateTensorValue( GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); + a_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + b_n_k.GenerateTensorValue(GeneratorTensor_3{-1., 1.}); + b_scales.GenerateTensorValue( + GeneratorTensor_2{126, 129}); // scales: {0.5, 1, 2} + break; default: // all initial values are representable in FP8, BF8 a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] @@ -1207,6 +1403,11 @@ struct TestMFMA a_m_k.GenerateTensorValue(GeneratorTensor_4(-1, 3)); b_n_k.GenerateTensorValue(GeneratorTensor_4(1, 3)); break; + case 4: + // FP4 values case + a_m_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); + b_n_k.GenerateTensorValue(GeneratorTensor_2{-4, 5}); + break; default: // all initial values are representable in FP8, BF8 a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6});