From 22a99f97e8e7d8eed03e0485035ce5dd9b9695ec Mon Sep 17 00:00:00 2001
From: Tianyuan Wu
Date: Sat, 30 May 2026 01:28:48 +0000
Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#7677 (commit 308af93) [CK_Tile] Add scale16 Support for F4 WMMA in CK_Tile

## Motivation

This PR adds CK Tile support for the scale16 F4 WMMA path on gfx1250 and improves warp GEMM unit test coverage/structure for gfx1250-specific cases.

## Technical Details

- Scale16 support in warp GEMM dispatch and WMMA trait plumbing: added IsScale16 plumbing to the warp GEMM dispatcher path.
- Warp GEMM test restructuring for gfx1250: added warp GEMM gfx1250 coverage to verify all F4 WMMA paths.

## Test Plan

Run ./test_ck_tile_wg_32x16x128_fp4.

## Test Result

```
./test_ck_tile_wg_32x16x128_fp4
[----------] Global test environment tear-down
[==========] 3 tests from 1 test suite ran. (1751 ms total)
[ PASSED ] 3 tests.
```

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
--- include/ck_tile/core/numeric/mxfp_scale.hpp | 94 ++++++++ .../gemm/warp/warp_gemm_attribute_wmma.hpp | 12 +- .../warp/warp_gemm_attribute_wmma_impl.hpp | 20 +- ...p_gemm_attribute_wmma_impl_8bit_traits.hpp | 74 ++++-- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 48 +++- .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 21 +- .../ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp | 16 +- test/ck_tile/warp_gemm/CMakeLists.txt | 4 + .../warp_gemm/test_f32_16x16x128_fp4.cpp | 190 +-------------- .../warp_gemm/test_f32_32x16x128_fp4.cpp | 39 +++ test/ck_tile/warp_gemm/test_gemm_util.hpp | 223 ++++++++++++++++++ 11 files changed, 507 insertions(+), 234 deletions(-) create mode 100644 test/ck_tile/warp_gemm/test_f32_32x16x128_fp4.cpp create mode 100644 test/ck_tile/warp_gemm/test_gemm_util.hpp diff --git a/include/ck_tile/core/numeric/mxfp_scale.hpp b/include/ck_tile/core/numeric/mxfp_scale.hpp index 1eb8063c02..54687e604f 100644 --- a/include/ck_tile/core/numeric/mxfp_scale.hpp +++ b/include/ck_tile/core/numeric/mxfp_scale.hpp @@ -103,7 +103,101 @@ struct Packed4Scale } }; +template +struct Packed8Scale +{ + using scale_type = ScaleType; + using raw_type = uint64_t; + using raw_scale_type = typename ScaleType::raw_type; + + static constexpr int num_pack = 8; + union + { + raw_type data_; + raw_scale_type scales_[num_pack]; // Direct byte/element access + }; + + // Constructors + CK_TILE_HOST_DEVICE constexpr Packed8Scale() = default; + CK_TILE_HOST_DEVICE constexpr Packed8Scale(raw_type val) : data_(val) {} + CK_TILE_HOST_DEVICE constexpr Packed8Scale( + float s0, float s1, float s2, float s3, float s4, float s5, float s6, float s7) + { + set_scales_from_float(s0, s1, s2, s3, s4, s5, s6, s7); + } + + CK_TILE_HOST_DEVICE constexpr Packed8Scale(ScaleType s0, + ScaleType s1, + ScaleType s2, + ScaleType s3, + ScaleType s4, + ScaleType s5, + ScaleType s6, + ScaleType s7) + { + set_scales(s0, s1, s2, s3, s4, s5, s6, s7); + } + + CK_TILE_HOST_DEVICE constexpr void set_scales_from_float( 
+ float s0, float s1, float s2, float s3, float s4, float s5, float s6, float s7) + { + set_scales(ScaleType(s0), + ScaleType(s1), + ScaleType(s2), + ScaleType(s3), + ScaleType(s4), + ScaleType(s5), + ScaleType(s6), + ScaleType(s7)); + } + + CK_TILE_HOST_DEVICE constexpr void set_scales(ScaleType s0, + ScaleType s1, + ScaleType s2, + ScaleType s3, + ScaleType s4, + ScaleType s5, + ScaleType s6, + ScaleType s7) + { + data_ = 0; + pack_scale(s0, 7); + pack_scale(s1, 6); + pack_scale(s2, 5); + pack_scale(s3, 4); + pack_scale(s4, 3); + pack_scale(s5, 2); + pack_scale(s6, 1); + pack_scale(s7, 0); + } + + CK_TILE_HOST_DEVICE constexpr operator raw_type() const { return data_; } + CK_TILE_HOST_DEVICE constexpr raw_type& data() [[clang::lifetimebound]] { return data_; } + CK_TILE_HOST_DEVICE constexpr raw_type data() const { return data_; } + + CK_TILE_HOST_DEVICE constexpr float unpack_to_float(int i) const + { + return static_cast(unpack_scale(i)); + } + + CK_TILE_HOST_DEVICE constexpr ScaleType unpack_scale(int i) const + { + return ScaleType(scales_[i]); + } + + CK_TILE_HOST_DEVICE constexpr void pack_from_float(float scale, int i) + { + pack_scale(ScaleType(scale), i); + } + + CK_TILE_HOST_DEVICE constexpr void pack_scale(ScaleType scale, int i) + { + scales_[i] = scale.get(); + } +}; + // Type alias for e8m0_t scales using Packed4Scale_E8M0 = Packed4Scale; +using Packed8Scale_E8M0 = Packed8Scale; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index 8cbaa9bfc8..9947915cbe 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -234,12 +234,12 @@ struct WarpGemmAttributeWmma } // c_vec += a_vec * b_vec - template + template CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, - const int32_t& a_scale, + const AScaleType& a_scale, const BVecType& b_vec, - const 
int32_t& b_scale) const + const BScaleType& b_scale) const { if constexpr(kTransC) { @@ -253,11 +253,11 @@ struct WarpGemmAttributeWmma } // c_vec = a_vec * b_vec - template + template CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, - const int32_t& a_scale, + const AScaleType& a_scale, const BVecType& b_vec, - const int32_t& b_scale) const + const BScaleType& b_scale) const { if constexpr(kTransC) { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp index 8aa02aba6e..8fd185cb42 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -19,6 +19,11 @@ template struct WmmaTraits; +// Tag used to select scale16 WMMA traits specializations. +struct WmmaScale16Tag +{ +}; + // Generic WMMA implementation using traits template struct WarpGemmAttributeWmmaImpl @@ -88,22 +93,22 @@ struct WarpGemmAttributeWmmaImpl Traits::template wmma_intrinsic(a_vec, b_vec, CVecType{0.f})); } - template + template CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec, - const int32_t& a_scale, + const AScaleType& a_scale, const BVecType& b_vec, - const int32_t& b_scale) const + const BScaleType& b_scale) const { c_vec = Traits::template wmma_intrinsic(a_vec, a_scale, b_vec, b_scale, c_vec); } // c_vec = a_vec * b_vec - template + template CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, - const int32_t& a_scale, + const AScaleType& a_scale, const BVecType& b_vec, - const int32_t& b_scale) const + const BScaleType& b_scale) const { return bit_cast(Traits::template wmma_intrinsic( a_vec, a_scale, b_vec, b_scale, CVecType{0.f})); @@ -177,6 +182,9 @@ using WarpGemmAttributeWmmaImpl_f32_32x16x128_f4 = using WarpGemmAttributeWmmaImpl_f32_32x32x128_f4 = WarpGemmAttributeWmmaImpl>; +using WarpGemmAttributeWmmaImpl_f32_32x32x128_f4_scale16 = WarpGemmAttributeWmmaImpl< + WmmaTraits>; 
+ using WarpGemmAttributeWmmaImpl_f16_16x16x64_f8_f8 = WarpGemmAttributeWmmaImpl>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp index dcc40304f4..77dafd0956 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp @@ -6,6 +6,9 @@ #include "warp_gemm_attribute_wmma_impl_base_traits.hpp" #include "warp_gemm_params.hpp" namespace ck_tile { + +struct WmmaScale16Tag; + // int8 specialization - GFX11 template <> struct WmmaTraits @@ -528,17 +531,18 @@ struct WmmaTraits } }; -template <> -struct WmmaTraits +template +struct WmmaTraitsGfx125PkFp4F32_32x32x128 : WmmaTraitsBase { - using ArchType = gfx125_t; + using ArchType = gfx125_t; + using ScaleType = std::conditional_t; template CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, - const int32_t& a_scale, + const ScaleType& a_scale, const BVecType& b_vec, - const int32_t& b_scale, + const ScaleType& b_scale, const CVecType& c_vec) { #ifdef __gfx125__ @@ -569,19 +573,38 @@ struct WmmaTraits const auto& b_slice = b_buffer.template get_as()[n]; auto& c_slice = c_result.template get_as()[n]; - c_slice = __builtin_amdgcn_wmma_scale_f32_32x16x128_f4( - bit_cast(a_slice), - bit_cast(b_slice), - 0, - c_slice, - 1, // OPSEL[0] - fixed to 1 for F4 - P::scale_a, // OPSEL_HI[0] - scale data type for A - a_scale, - n.value, // OPSEL[1] - select B scale (iterates over N blocks) - P::scale_b, // OPSEL_HI[1] - scale data type for B - b_scale, - 0, // NEG - 0); // NEG_HI + if constexpr(IsScale16) + { + c_slice = __builtin_amdgcn_wmma_scale16_f32_32x16x128_f4( + bit_cast(a_slice), + bit_cast(b_slice), + 0, + c_slice, + 1, // OPSEL[0] - fixed to 1 for F4 + P::scale_a, // OPSEL_HI[0] - scale data type for A + a_scale, + n.value, // OPSEL[1] - select B scale (iterates over N blocks) + 
P::scale_b, // OPSEL_HI[1] - scale data type for B + b_scale, + 0, // NEG + 0); // NEG_HI + } + else + { + c_slice = __builtin_amdgcn_wmma_scale_f32_32x16x128_f4( + bit_cast(a_slice), + bit_cast(b_slice), + 0, + c_slice, + 1, // OPSEL[0] - fixed to 1 for F4 + P::scale_a, // OPSEL_HI[0] - scale data type for A + a_scale, + n.value, // OPSEL[1] - select B scale (iterates over N blocks) + P::scale_b, // OPSEL_HI[1] - scale data type for B + b_scale, + 0, // NEG + 0); // NEG_HI + } }); return bit_cast(c_result); @@ -602,7 +625,8 @@ struct WmmaTraits #ifdef __gfx125__ // Pass default scale values 1.0f Packed4Scale_E8M0 pkscale(1.0f, 1.0f, 1.0f, 1.0f); - return wmma_intrinsic(a_vec, pkscale, b_vec, pkscale, c_vec); + const auto default_scale = static_cast(pkscale); + return wmma_intrinsic(a_vec, default_scale, b_vec, default_scale, c_vec); #else ck_tile::ignore = a_vec; ck_tile::ignore = b_vec; @@ -612,6 +636,18 @@ struct WmmaTraits } }; +template <> +struct WmmaTraits + : WmmaTraitsGfx125PkFp4F32_32x32x128 +{ +}; + +template <> +struct WmmaTraits + : WmmaTraitsGfx125PkFp4F32_32x32x128 +{ +}; + // f8f6f4 specialization - GFX125 enum F8F6F4OpDataTypeEnum { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 4027b8ed34..2e6fa605ba 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -33,6 +33,7 @@ template struct Dispatcher; @@ -178,10 +179,10 @@ template<> struct Dispatcher { using Ty #if !defined(__gfx125__) // scale mfma based f8f6f4 -template -struct Dispatcher> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4; }; -template -struct Dispatcher> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4_CTransposed; }; +template +struct Dispatcher> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4; }; +template +struct Dispatcher> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4_CTransposed; }; #endif template<> struct 
Dispatcher { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; }; @@ -224,7 +225,7 @@ template struct Dispatcher struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_16x16x64_bf8_f8; }; template struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_32x16x128_f4; }; -template struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f4; }; +template struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f4; }; #if defined(__gfx125__) template struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_16x16x64_f8_f8; }; @@ -244,8 +245,27 @@ template<> struct Dispatcher { using Typ template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; }; #endif -template -struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f8f6f4; }; +template +struct Dispatcher : WmmaTag +{ + using Type = WarpGemmWmma_f32_32x32x128_f8f6f4; +}; template struct Dispatcher : WmmaTag { using Type =WarpGemmWmma_f16_16x16x64_f8_f8; }; template struct Dispatcher : WmmaTag { using Type =WarpGemmWmma_f16_16x16x64_bf8_bf8; }; @@ -265,12 +285,12 @@ template struct Dispatcher + bool TransposeC, bool SA, bool SS, bool IsScale16> struct Dispatcher>>> - : Dispatcher {}; + Dispatcher>>> + : Dispatcher {}; // clang-format on } // namespace warp_gemm_dispatcher @@ -286,7 +306,8 @@ template + WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA, + bool IsScale16 = false> using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // AType, BType, @@ -298,6 +319,7 @@ using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< // SwizzleA, UseStructuredSparsity, AttrNumAccessA, - AttrNumAccessB>::Type; + AttrNumAccessB, + IsScale16>::Type; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index f0353672a0..6801f627c7 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ 
b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -90,12 +90,17 @@ struct WarpGemmImpl c.get_thread_buffer().template set_as(I0, c_vec); } - template + template CK_TILE_DEVICE void operator()(CTensor& c, const ATensor& a, const BTensor& b, - const int32_t& a_scale, - const int32_t& b_scale) const + const AScaleType& a_scale, + const BScaleType& b_scale) const { static_assert(detail::is_similiar_distributed_tensor_v && detail::is_similiar_distributed_tensor_v && @@ -141,11 +146,15 @@ struct WarpGemmImpl return c; } - template + template CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b, - const int32_t& a_scale, - const int32_t& b_scale) const + const AScaleType& a_scale, + const BScaleType& b_scale) const { using CTensor = CWarpTensor; static_assert(detail::is_similiar_distributed_tensor_v && diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp index e7b601306f..1c522d07c1 100644 --- a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp @@ -187,12 +187,16 @@ using WarpGemmWmma_f32_32x16x128_f4 = AttrNumAccess, AttrNumAccess>>; -template -using WarpGemmWmma_f32_32x32x128_f4 = - WarpGemmImpl>; +template +using WarpGemmWmma_f32_32x32x128_f4 = WarpGemmImpl< + WarpGemmAttributeWmma, + kTransC, + AttrNumAccess, + AttrNumAccess>>; template -#include "ck_tile/host.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" - -using namespace ck_tile; - -template -struct WGDispCase -{ - using AType = A; - using BType = B; - using AccType = Acc; - static constexpr index_t MPerWave = M; - static constexpr index_t NPerWave = N; - static constexpr index_t KPerWave = K; - static constexpr bool kTransposeC = TransposeC; - static constexpr bool kSwizzleA = SwizzleA; - static constexpr bool kUSS = UseStructuredSparsity; - static constexpr WGAttrNumAccessEnum kNA = NA; -}; +#include "test_gemm_util.hpp" 
+#include "gtest/gtest.h" using WGDispatcherTypesList = - ::testing::Types>; + ::testing::Types>; -template -struct WarpGemmKernel -{ - static constexpr int kBlockSize = 64; - __device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const - { - using WarpGemm = ck_tile::WarpGemmDispatcher; - // A: [M,K] row-major (packed) - const auto a_view = ck_tile::make_naive_tensor_view( - static_cast(A), - ck_tile::make_tuple(M, K), - ck_tile::make_tuple(K, ck_tile::number<1>{}), - ck_tile::number{}, - ck_tile::number<1>{}); - // B: expose as logical [N,K] with strides (1, N) over the original row-major [K,N] buffer - const auto b_view = ck_tile::make_naive_tensor_view( - static_cast(B), - ck_tile::make_tuple(N, K), - ck_tile::make_tuple(K, ck_tile::number<1>{}), - ck_tile::number{}, - ck_tile::number<1>{}); - // C: [M,N] row-major (packed) - const auto c_view = ck_tile::make_naive_tensor_view( - static_cast(C), - ck_tile::make_tuple(M, N), - ck_tile::make_tuple(N, ck_tile::number<1>{}), - ck_tile::number{}, - ck_tile::number<1>{}); - - using AWarpTensor = typename WarpGemm::AWarpTensor; - using BWarpTensor = typename WarpGemm::BWarpTensor; - using CWarpTensor = typename WarpGemm::CWarpTensor; - - constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths(); - constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths(); - constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths(); - - auto a_win = ck_tile::make_tile_window( - a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution()); - auto b_win = ck_tile::make_tile_window( - b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution()); - auto c_win = ck_tile::make_tile_window( - c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution()); - - AWarpTensor a_tile; - BWarpTensor b_tile; - ck_tile::load_tile(a_tile, a_win); - ck_tile::load_tile(b_tile, b_win); - - auto scale_a = 
static_cast(static_cast(ScaleA)[0].get()); - auto scale_b = static_cast(static_cast(ScaleB)[0].get()); - - auto c_tile = - WarpGemm{}.template operator(), OpSelB<0>>(a_tile, b_tile, scale_a, scale_b); - - ck_tile::store_tile(c_win, c_tile); - } -}; - -template -static void RunWarpGemmCase(const ck_tile::HostTensor& A, - const ck_tile::HostTensor& B, - const ck_tile::HostTensor& ScaleA, - const ck_tile::HostTensor& ScaleB, - ck_tile::HostTensor& C) -{ - ck_tile::DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB); - dim3 grid(1), block{64}; - - using Kernel = WarpGemmKernel; - - (void)ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 0, 1}, - ck_tile::make_kernel(Kernel{}, - grid, - block, - 0, - Ad.GetDeviceBuffer(), - Bd.GetDeviceBuffer(), - Cd.GetDeviceBuffer(), - SAd.GetDeviceBuffer(), - SBd.GetDeviceBuffer())); - - Cd.FromDevice(C.mData.data()); -} - -template +template class WGRuntimeTest : public ::testing::Test { }; @@ -156,38 +22,6 @@ TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG) { - using Case = TypeParam; - - using AType = typename Case::AType; - using BType = typename Case::BType; - using CType = typename Case::AccType; - using ck_tile::e8m0_t; - - constexpr index_t M = Case::MPerWave; - constexpr index_t N = Case::NPerWave; - constexpr index_t K = Case::KPerWave; - - auto ScaleA = e8m0_t{2.f}; - auto ScaleB = e8m0_t{4.f}; - - ck_tile::HostTensor A({M, K}); - ck_tile::HostTensor B({N, K}); - ck_tile::HostTensor C({M, N}); - ck_tile::HostTensor sA({M, 1}); - ck_tile::HostTensor sB({N, 1}); - - ck_tile::FillUniformDistribution{-5.f, 5.f}(A); - ck_tile::FillUniformDistribution{-5.f, 5.f}(B); - C.SetZero(); - ck_tile::FillConstant{ScaleA}(sA); - ck_tile::FillConstant{ScaleB}(sB); - - RunWarpGemmCase(A, B, sA, sB, C); - - ck_tile::HostTensor C_ref({M, N}); - C_ref.SetZero(); - ck_tile::reference_mx_gemm( - A, B.transpose(), C_ref, sA, sB.transpose()); - - 
EXPECT_TRUE(ck_tile::check_err(C, C_ref, "Warp gemm result error.")); + ck_tile::test::warp_gemm:: + RunCompareDispatcherAndReference(); } diff --git a/test/ck_tile/warp_gemm/test_f32_32x16x128_fp4.cpp b/test/ck_tile/warp_gemm/test_f32_32x16x128_fp4.cpp new file mode 100644 index 0000000000..be394ee3a3 --- /dev/null +++ b/test/ck_tile/warp_gemm/test_f32_32x16x128_fp4.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_util.hpp" +#include "gtest/gtest.h" + +using WGDispatcherTypesList = + ::testing::Types>; + +template +class WGRuntimeTest : public ::testing::Test +{ +}; + +TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList); + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_NonScaled) +{ + ck_tile::test::warp_gemm:: + RunCompareDispatcherAndReference(); +} + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_Scale16) +{ + ck_tile::test::warp_gemm:: + RunCompareDispatcherAndReference(); +} + +TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_Scale32) +{ + ck_tile::test::warp_gemm:: + RunCompareDispatcherAndReference(); +} diff --git a/test/ck_tile/warp_gemm/test_gemm_util.hpp b/test/ck_tile/warp_gemm/test_gemm_util.hpp new file mode 100644 index 0000000000..dcd1d3f342 --- /dev/null +++ b/test/ck_tile/warp_gemm/test_gemm_util.hpp @@ -0,0 +1,223 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/core/numeric/mxfp_scale.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +namespace ck_tile::test::warp_gemm { + +template +struct WGDispCase +{ + using AType = A; + using BType = B; + using AccType = Acc; + static constexpr bool kTransposeC = TransposeC; + static constexpr bool kSwizzleA = SwizzleA; + static constexpr bool kUSS = UseStructuredSparsity; + static constexpr WGAttrNumAccessEnum kNA = NA; +}; + +template +struct WarpGemmKernel +{ + static constexpr int kBlockSize = 64; + __device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const + { + using WarpGemm = ck_tile::WarpGemmDispatcher; + + const auto a_view = ck_tile::make_naive_tensor_view( + static_cast(A), + ck_tile::make_tuple(MPerWave, KPerWave), + ck_tile::make_tuple(KPerWave, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + const auto b_view = ck_tile::make_naive_tensor_view( + static_cast(B), + ck_tile::make_tuple(NPerWave, KPerWave), + ck_tile::make_tuple(KPerWave, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + const auto c_view = ck_tile::make_naive_tensor_view( + static_cast(C), + ck_tile::make_tuple(MPerWave, NPerWave), + ck_tile::make_tuple(NPerWave, ck_tile::number<1>{}), + ck_tile::number{}, + ck_tile::number<1>{}); + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths(); + constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths(); + + auto a_win = ck_tile::make_tile_window( + a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution()); + auto b_win = 
ck_tile::make_tile_window( + b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution()); + auto c_win = ck_tile::make_tile_window( + c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution()); + + AWarpTensor a_tile; + BWarpTensor b_tile; + ck_tile::load_tile(a_tile, a_win); + ck_tile::load_tile(b_tile, b_win); + + const auto c_tile = [&]() { + if constexpr(UseScale) + { + using ScaleType = std::conditional_t; + const auto scale_a = static_cast(ScaleA)[0]; + const auto scale_b = static_cast(ScaleB)[0]; + const auto packed_scale_a = [&]() -> ScaleType { + if constexpr(IsScale16) + { + Packed8Scale_E8M0 pkscale( + scale_a, scale_a, scale_a, scale_a, scale_a, scale_a, scale_a, scale_a); + return static_cast(pkscale); + } + else + { + Packed4Scale_E8M0 pkscale(scale_a, scale_a, scale_a, scale_a); + return static_cast(pkscale); + } + }(); + const auto packed_scale_b = [&]() -> ScaleType { + if constexpr(IsScale16) + { + Packed8Scale_E8M0 pkscale( + scale_b, scale_b, scale_b, scale_b, scale_b, scale_b, scale_b, scale_b); + return static_cast(pkscale); + } + else + { + Packed4Scale_E8M0 pkscale(scale_b, scale_b, scale_b, scale_b); + return static_cast(pkscale); + } + }(); + return WarpGemm{}.template operator(), OpSelB<0>>( + a_tile, b_tile, packed_scale_a, packed_scale_b); + } + else + { + ck_tile::ignore = ScaleA; + ck_tile::ignore = ScaleB; + return WarpGemm{}.template operator(), OpSelB<0>>(a_tile, b_tile); + } + }(); + + ck_tile::store_tile(c_win, c_tile); + } +}; + +template +void RunWarpGemmCase(const ck_tile::HostTensor& A, + const ck_tile::HostTensor& B, + const ck_tile::HostTensor& ScaleA, + const ck_tile::HostTensor& ScaleB, + ck_tile::HostTensor& C) +{ + ck_tile::DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB); + dim3 grid(1), block{64}; + + (void)ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 0, 1}, + ck_tile::make_kernel( + WarpGemmKernel{}, + grid, + block, + 0, + 
Ad.GetDeviceBuffer(), + Bd.GetDeviceBuffer(), + Cd.GetDeviceBuffer(), + SAd.GetDeviceBuffer(), + SBd.GetDeviceBuffer())); + + Cd.FromDevice(C.mData.data()); +} + +template +void RunCompareDispatcherAndReference() +{ + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr index_t M = MPerWave; + constexpr index_t N = NPerWave; + constexpr index_t K = KPerWave; + + const auto ScaleA = ck_tile::e8m0_t{2.f}; + const auto ScaleB = ck_tile::e8m0_t{4.f}; + + ck_tile::HostTensor A({M, K}); + ck_tile::HostTensor B({N, K}); + ck_tile::HostTensor C({M, N}); + ck_tile::HostTensor sA({M, 1}); + ck_tile::HostTensor sB({N, 1}); + + ck_tile::FillUniformDistribution{-5.f, 5.f}(A); + ck_tile::FillUniformDistribution{-5.f, 5.f}(B); + C.SetZero(); + ck_tile::FillConstant{ScaleA}(sA); + ck_tile::FillConstant{ScaleB}(sB); + + RunWarpGemmCase(A, B, sA, sB, C); + + ck_tile::HostTensor C_ref({M, N}); + C_ref.SetZero(); + + if constexpr(UseScale) + { + ck_tile::reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + } + else + { + ck_tile::reference_gemm(A, B.transpose(), C_ref); + } + + EXPECT_TRUE(ck_tile::check_err(C, C_ref, "Warp gemm result error.")); +} + +} // namespace ck_tile::test::warp_gemm