diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index 9947915cbe..360408f319 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -154,6 +154,19 @@ struct CTransposedWarpDstrEncodingTrait typename Impl::kCTYs2RHsMinor>; }; +namespace detail { +template +struct mx_type_enable_or_void +{ + using type = void; +}; +template +struct mx_type_enable_or_void> +{ + using type = typename T::MXTypeEnableType; +}; +} // namespace detail + template ; - // When kTransC is true and A/B types differ, we need an impl with swapped types - using TransposedImpl = - std::conditional_t, - WarpGemmAttributeWmmaImpl>, - Impl>; + // When kTransC is true and A/B types differ, we need an impl with swapped types. + // Propagate MXTypeEnable (e.g., WmmaScale16Tag) so the transposed impl uses the + // same WmmaTraits specialization family. + using TransposedImpl = std::conditional_t< + kTransC && !std::is_same_v, + WarpGemmAttributeWmmaImpl< + WmmaTraits::type>>, + Impl>; using ADataType = typename Impl::ADataType; using BDataType = typename Impl::BDataType; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp index 6f38199828..2883acdf5a 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl.hpp @@ -204,6 +204,13 @@ template using WarpGemmAttributeWmmaImpl_f32_16x16x128_f8f6f4 = WarpGemmAttributeWmmaImpl>; +// WmmaScale16Tag (declared above) is passed as MXTypeEnable to WmmaTraits to select scale16 +// specializations. These override kAK1PerLane=16 (-> sequence<4,2,16>) and use int64_t scales +// for V_WMMA_SCALE16_F32_16X16X128_F8F6F4, vs the default layout / int32_t. 
+template +using WarpGemmAttributeWmmaImpl_f32_16x16x128_f8f6f4_scale16 = WarpGemmAttributeWmmaImpl< + WmmaTraits>; + template using WarpGemmAttributeWmmaImpl_f32_32x32x128_f8f6f4 = WarpGemmAttributeWmmaImpl>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp index 77dafd0956..3631b94d32 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma_impl_8bit_traits.hpp @@ -427,6 +427,82 @@ struct WmmaTraits } }; +// f8f6f4 specialization - GFX125 +enum F8F6F4OpDataTypeEnum +{ + E4M3, // 0x0 + E5M2, // 0x1 + E2M3, // 0x2 + E3M2, // 0x3 + E2M1, // 0x4 +}; + +// Traits for MX data types used in f8f6f4 intrinsics +template +struct MXDataTypeTrait; + +template <> +struct MXDataTypeTrait +{ + static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E4M3; + using VecType = int32x16_t; + + CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; } +}; + +template <> +struct MXDataTypeTrait +{ + static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E5M2; + using VecType = int32x16_t; + + CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; } +}; + +template <> +struct MXDataTypeTrait +{ + static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M1; + using VecType = int32x8_t; + + CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x8_t& vec) + { + return int32x16_t{ + vec[0], vec[1], vec[2], vec[3], vec[4], vec[5], vec[6], vec[7], 0, 0, 0, 0, 0, 0, 0, 0}; + } +}; + +// pk_fp6x16_t (legacy): 16 fp6 e2m3 values packed into 3 int32 (96 bits). +// At 16x16x128 each lane holds 64 fp6 elements = 4 packs = 12 int32 +// (f6x16xN_tt<4, f6_kind::fp6>, whose storage is int32_t data[12]), +// padded with 4 zero lanes to fit the 16-wide f8f6f4 wmma input. 
+template <> +struct MXDataTypeTrait +{ + static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M3; + using VecType = f6x16xN_tt<4, f6_kind::fp6>; + + CK_TILE_DEVICE static int32x16_t to_wmma_vec(const f6x16xN_tt<4, f6_kind::fp6>& vec) + { + return int32x16_t{vec.data[0], + vec.data[1], + vec.data[2], + vec.data[3], + vec.data[4], + vec.data[5], + vec.data[6], + vec.data[7], + vec.data[8], + vec.data[9], + vec.data[10], + vec.data[11], + 0, + 0, + 0, + 0}; + } +}; + template <> struct WmmaTraits : WmmaTraitsBase @@ -508,6 +584,232 @@ struct WmmaTraits } }; +// scale16 specializations: fp8xfp8, bf8xbf8, fp8xbf8, bf8xfp8 +// Override kAK1PerLane/kBK1PerLane to 16 for scale16 register layout -> sequence<4,2,16> +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + using ArchType = gfx125_t; + using MXTypeEnableType = WmmaScale16Tag; + + static constexpr index_t kAK1PerLane = 16; + static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane); + static constexpr index_t kBK1PerLane = 16; + static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane); + + template + CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&) + { + static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments"); + return CVecType{0}; + } + + template + CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, + const int64_t& a_scale, + const BVecType& b_vec, + const int64_t& b_scale, + const CVecType& c_vec) + { +#ifdef __gfx125__ + using P = WarpGemmParamsParser; + using ATraits = MXDataTypeTrait; + using BTraits = MXDataTypeTrait; + return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4( + ATraits::OpDataType, + ATraits::to_wmma_vec(bit_cast(a_vec)), + BTraits::OpDataType, + BTraits::to_wmma_vec(bit_cast(b_vec)), + 0, + bit_cast(c_vec), + P::op_sel_a, + P::scale_a, + a_scale, + P::op_sel_b, + P::scale_b, + b_scale, + 0, + 0); +#else + ck_tile::ignore = a_vec; + 
ck_tile::ignore = a_scale; + ck_tile::ignore = b_vec; + ck_tile::ignore = b_scale; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + using ArchType = gfx125_t; + using MXTypeEnableType = WmmaScale16Tag; + + static constexpr index_t kAK1PerLane = 16; + static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane); + static constexpr index_t kBK1PerLane = 16; + static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane); + + template + CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&) + { + static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments"); + return CVecType{0}; + } + + template + CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, + const int64_t& a_scale, + const BVecType& b_vec, + const int64_t& b_scale, + const CVecType& c_vec) + { +#ifdef __gfx125__ + using P = WarpGemmParamsParser; + using ATraits = MXDataTypeTrait; + using BTraits = MXDataTypeTrait; + return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4( + ATraits::OpDataType, + ATraits::to_wmma_vec(bit_cast(a_vec)), + BTraits::OpDataType, + BTraits::to_wmma_vec(bit_cast(b_vec)), + 0, + bit_cast(c_vec), + P::op_sel_a, + P::scale_a, + a_scale, + P::op_sel_b, + P::scale_b, + b_scale, + 0, + 0); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = a_scale; + ck_tile::ignore = b_vec; + ck_tile::ignore = b_scale; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + using ArchType = gfx125_t; + using MXTypeEnableType = WmmaScale16Tag; + + static constexpr index_t kAK1PerLane = 16; + static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane); + static constexpr index_t kBK1PerLane = 16; + static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane); + + template + CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const 
BVecType&, const CVecType&) + { + static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments"); + return CVecType{0}; + } + + template + CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, + const int64_t& a_scale, + const BVecType& b_vec, + const int64_t& b_scale, + const CVecType& c_vec) + { +#ifdef __gfx125__ + using P = WarpGemmParamsParser; + using ATraits = MXDataTypeTrait; + using BTraits = MXDataTypeTrait; + return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4( + ATraits::OpDataType, + ATraits::to_wmma_vec(bit_cast(a_vec)), + BTraits::OpDataType, + BTraits::to_wmma_vec(bit_cast(b_vec)), + 0, + bit_cast(c_vec), + P::op_sel_a, + P::scale_a, + a_scale, + P::op_sel_b, + P::scale_b, + b_scale, + 0, + 0); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = a_scale; + ck_tile::ignore = b_vec; + ck_tile::ignore = b_scale; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + +template <> +struct WmmaTraits + : WmmaTraitsBase +{ + using ArchType = gfx125_t; + using MXTypeEnableType = WmmaScale16Tag; + + static constexpr index_t kAK1PerLane = 16; + static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane); + static constexpr index_t kBK1PerLane = 16; + static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane); + + template + CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&) + { + static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments"); + return CVecType{0}; + } + + template + CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec, + const int64_t& a_scale, + const BVecType& b_vec, + const int64_t& b_scale, + const CVecType& c_vec) + { +#ifdef __gfx125__ + using P = WarpGemmParamsParser; + using ATraits = MXDataTypeTrait; + using BTraits = MXDataTypeTrait; + return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4( + ATraits::OpDataType, + ATraits::to_wmma_vec(bit_cast(a_vec)), + 
BTraits::OpDataType, + BTraits::to_wmma_vec(bit_cast(b_vec)), + 0, + bit_cast(c_vec), + P::op_sel_a, + P::scale_a, + a_scale, + P::op_sel_b, + P::scale_b, + b_scale, + 0, + 0); +#else + ck_tile::ignore = a_vec; + ck_tile::ignore = a_scale; + ck_tile::ignore = b_vec; + ck_tile::ignore = b_scale; + ck_tile::ignore = c_vec; + return CVecType{0}; +#endif + } +}; + // 32x16x128 f4 specialization - GFX125 template <> struct WmmaTraits @@ -648,82 +950,6 @@ struct WmmaTraits -struct MXDataTypeTrait; - -template <> -struct MXDataTypeTrait -{ - static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E4M3; - using VecType = int32x16_t; - - CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; } -}; - -template <> -struct MXDataTypeTrait -{ - static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E5M2; - using VecType = int32x16_t; - - CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; } -}; - -template <> -struct MXDataTypeTrait -{ - static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M1; - using VecType = int32x8_t; - - CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x8_t& vec) - { - return int32x16_t{ - vec[0], vec[1], vec[2], vec[3], vec[4], vec[5], vec[6], vec[7], 0, 0, 0, 0, 0, 0, 0, 0}; - } -}; - -// pk_fp6x16_t (legacy): 16 fp6 e2m3 values packed into 3 int32 (96 bits). -// At 16x16x128 each lane holds 64 fp6 elements = 4 packs = 12 int32 -// (f6x16xN_tt<4, f6_kind::fp6>, whose storage is int32_t data[12]), -// padded with 4 zero lanes to fit the 16-wide f8f6f4 wmma input. 
-template <> -struct MXDataTypeTrait -{ - static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M3; - using VecType = f6x16xN_tt<4, f6_kind::fp6>; - - CK_TILE_DEVICE static int32x16_t to_wmma_vec(const f6x16xN_tt<4, f6_kind::fp6>& vec) - { - return int32x16_t{vec.data[0], - vec.data[1], - vec.data[2], - vec.data[3], - vec.data[4], - vec.data[5], - vec.data[6], - vec.data[7], - vec.data[8], - vec.data[9], - vec.data[10], - vec.data[11], - 0, - 0, - 0, - 0}; - } -}; - // Unified WmmaTraits for f8f6f4 combinations template struct WmmaTraits< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index b77f09fa9f..930b9294df 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -240,6 +240,9 @@ template struct Dispatcher struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_16x16x128_f8f6f4; }; + +// F8F6F4 Scale16 (IsScale16=true) +template struct Dispatcher : WmmaTag { using Type = WarpGemmWmma_f32_16x16x128_f8f6f4_scale16; }; #else template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; }; diff --git a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp index 224937dc6d..26540a4642 100644 --- a/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp @@ -216,6 +216,17 @@ using WarpGemmWmma_f32_16x16x128_f8f6f4 = AttrNumAccessA, AttrNumAccessB>>; +template +using WarpGemmWmma_f32_16x16x128_f8f6f4_scale16 = WarpGemmImpl< + WarpGemmAttributeWmma, + kTransC, + AttrNumAccessA, + AttrNumAccessB>>; + template +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +using namespace ck_tile; + 
+template +CK_TILE_DEVICE static constexpr auto MakeScaleDistribution() +{ + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<0>>{}); +} + +// Scale16 kernel using the WarpGemm wrapper (Layer 4). +// WarpGemmWmma_f32_16x16x128_f8f6f4_scale16 produces AWarpDstr with sequence<4,2,16>, +// matching the hardware's 1-scale-byte-per-16-K-elements register layout. +template +struct WarpGemmScale16Kernel +{ + static constexpr int kBlockSize = 32; + static constexpr index_t ScaleBlockK = 16; + static constexpr index_t NumScales = K / ScaleBlockK; + + __device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const + { + const auto a_view = + make_naive_tensor_view(static_cast(A), + make_tuple(M, K), + make_tuple(K, number<1>{}), + number{}, + number<1>{}); + const auto b_view = + make_naive_tensor_view(static_cast(B), + make_tuple(N, K), + make_tuple(K, number<1>{}), + number{}, + number<1>{}); + const auto c_view = + make_naive_tensor_view(static_cast(C), + make_tuple(M, N), + make_tuple(N, number<1>{}), + number{}, + number<1>{}); + const auto sa_view = + make_naive_tensor_view(static_cast(ScaleA), + make_tuple(M, NumScales), + make_tuple(NumScales, number<1>{}), + number{}, + number<1>{}); + const auto sb_view = + make_naive_tensor_view(static_cast(ScaleB), + make_tuple(N, NumScales), + make_tuple(NumScales, number<1>{}), + number{}, + number<1>{}); + + using WarpGemm = WarpGemmDispatcher; + constexpr auto a_dstr = typename WarpGemm::AWarpDstr{}; + constexpr auto b_dstr = typename WarpGemm::BWarpDstr{}; + constexpr auto c_dstr = typename WarpGemm::CWarpDstr{}; + constexpr auto scale_dstr = MakeScaleDistribution(); + + auto a_win = make_tile_window( + a_view, make_tuple(number{}, number{}), make_multi_index(0, 0), a_dstr); + auto b_win = make_tile_window( + b_view, make_tuple(number{}, number{}), make_multi_index(0, 0), b_dstr); + auto c_win = make_tile_window( + 
c_view, make_tuple(number{}, number{}), make_multi_index(0, 0), c_dstr); + auto sa_win = make_tile_window(sa_view, + make_tuple(number{}, number{}), + make_multi_index(0, 0), + scale_dstr); + auto sb_win = make_tile_window(sb_view, + make_tuple(number{}, number{}), + make_multi_index(0, 0), + scale_dstr); + + auto a_tile = load_tile(a_win); + auto b_tile = load_tile(b_win); + auto sa_tile = load_tile(sa_win); + auto sb_tile = load_tile(sb_win); + + int64_t scale_a = + bit_cast(sa_tile.get_thread_buffer() + .template get_as>()[number<0>{}]); + int64_t scale_b = + bit_cast(sb_tile.get_thread_buffer() + .template get_as>()[number<0>{}]); + + auto c_tile = WarpGemm{}(a_tile, b_tile, scale_a, scale_b); + store_tile(c_win, c_tile); + } +}; + +template +struct WGDispCase +{ + using AType = A; + using BType = B; + using AccType = Acc; + static constexpr index_t MPerWave = M; + static constexpr index_t NPerWave = N; + static constexpr index_t KPerWave = K; +}; + +template +static void RunTest(const HostTensor& A, + const HostTensor& B, + const HostTensor& ScaleA, + const HostTensor& ScaleB, + HostTensor& C) +{ + DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB); + dim3 grid(1), block{32}; + + using K = WarpGemmScale16Kernel; + + (void)launch_kernel(stream_config{nullptr, true, 0, 0, 1}, + make_kernel(K{}, + grid, + block, + 0, + Ad.GetDeviceBuffer(), + Bd.GetDeviceBuffer(), + Cd.GetDeviceBuffer(), + SAd.GetDeviceBuffer(), + SBd.GetDeviceBuffer())); + + Cd.FromDevice(C.mData.data()); +} + +using WGDispatcherTypesList = ::testing::Types, + WGDispCase, + WGDispCase, + WGDispCase>; + +template +class WGScale16Test : public ::testing::Test +{ +}; + +TYPED_TEST_SUITE(WGScale16Test, WGDispatcherTypesList); + +TYPED_TEST(WGScale16Test, Scale16_16x16x128_UniformScale) +{ + using Case = TypeParam; + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = 
Case::NPerWave; + constexpr index_t K = Case::KPerWave; + constexpr index_t NumScales = K / 16; + + HostTensor A({M, K}); + HostTensor B({N, K}); + HostTensor C({M, N}); + HostTensor sA({M, NumScales}); + HostTensor sB({N, NumScales}); + + FillUniformDistribution{-5.f, 5.f}(A); + FillUniformDistribution{-5.f, 5.f}(B); + C.SetZero(); + FillConstant{e8m0_t{2.f}}(sA); + FillConstant{e8m0_t{4.f}}(sB); + + RunTest(A, B, sA, sB, C); + + HostTensor C_ref({M, N}); + C_ref.SetZero(); + reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + + EXPECT_TRUE(check_err(C, C_ref, "Scale16 uniform scale error.")); +} + +TYPED_TEST(WGScale16Test, Scale16_16x16x128_RandomDataRandomScales) +{ + using Case = TypeParam; + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = Case::NPerWave; + constexpr index_t K = Case::KPerWave; + constexpr index_t NumScales = K / 16; + + HostTensor A({M, K}); + HostTensor B({N, K}); + HostTensor C({M, N}); + HostTensor sA({M, NumScales}); + HostTensor sB({N, NumScales}); + + FillUniformDistribution{-5.f, 5.f, 42}(A); + FillUniformDistribution{-5.f, 5.f, 137}(B); + C.SetZero(); + + { + constexpr int bias = ck_tile::numeric_traits::bias; + std::mt19937 gen(9999); + std::uniform_int_distribution dist(bias - 4, bias + 2); + for(auto& s : sA.mData) + s = e8m0_t(static_cast(dist(gen))); + for(auto& s : sB.mData) + s = e8m0_t(static_cast(dist(gen))); + } + + RunTest(A, B, sA, sB, C); + + HostTensor C_ref({M, N}); + C_ref.SetZero(); + reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + + const float max_acc = *std::max_element(C_ref.mData.begin(), C_ref.mData.end()); + const auto rtol = ck_tile::get_relative_threshold(K); + const auto atol = ck_tile::get_absolute_threshold(max_acc, K); + EXPECT_TRUE(check_err(C, C_ref, "Scale16 random data + random scales error.", rtol, atol)); +} + 
+TYPED_TEST(WGScale16Test, Scale16_16x16x128_TransposeC) +{ + using Case = TypeParam; + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = Case::NPerWave; + constexpr index_t K = Case::KPerWave; + constexpr index_t NumScales = K / 16; + + HostTensor A({M, K}); + HostTensor B({N, K}); + HostTensor C({M, N}); + HostTensor sA({M, NumScales}); + HostTensor sB({N, NumScales}); + + FillUniformDistribution{-5.f, 5.f, 77}(A); + FillUniformDistribution{-5.f, 5.f, 88}(B); + C.SetZero(); + + { + constexpr int bias = ck_tile::numeric_traits::bias; + std::mt19937 gen(5555); + std::uniform_int_distribution dist(bias - 4, bias + 2); + for(auto& s : sA.mData) + s = e8m0_t(static_cast(dist(gen))); + for(auto& s : sB.mData) + s = e8m0_t(static_cast(dist(gen))); + } + + RunTest(A, B, sA, sB, C); + + HostTensor C_ref({M, N}); + C_ref.SetZero(); + reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + + const float max_acc = *std::max_element(C_ref.mData.begin(), C_ref.mData.end()); + const auto rtol = ck_tile::get_relative_threshold(K); + const auto atol = ck_tile::get_absolute_threshold(max_acc, K); + EXPECT_TRUE(check_err(C, C_ref, "Scale16 TransposeC error.", rtol, atol)); +} diff --git a/test/ck_tile/warp_gemm/test_f32_32x32x128_fp8_scale16.cpp b/test/ck_tile/warp_gemm/test_f32_32x32x128_fp8_scale16.cpp new file mode 100644 index 0000000000..46726108a1 --- /dev/null +++ b/test/ck_tile/warp_gemm/test_f32_32x32x128_fp8_scale16.cpp @@ -0,0 +1,362 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+// +// Unit test for a 32x32x128 block-scaled GEMM built from the 16x16x128 scale16 +// warp gemm (V_WMMA_SCALE16_F32_16X16X128_F8F6F4), modeled after the MX GEMM +// pipelines (BlockGemmARegBRegCRegV1 / mx_flatmm pipeline): all A/B sub-tiles and +// the A/B K-scales for the 2x2 block-level loop are preloaded into registers in +// advance, then consumed by the warp-gemm loop. The scales are staged once across +// the 32 lanes (one row/column per lane, one int64 of e8m0 K-scales each), and the +// same scale_a / scale_b values are passed to every warp-gemm call. Per-M-block / +// per-N-block scale selection is done through OpSelA/OpSelB (lanes 0..15 for block +// 0, lanes 16..31 for block 1), not by indexing per-block scale arrays. + +#include +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" + +using namespace ck_tile; + +// Distribution that stages one scale row per lane: NumRows rows mapped across the +// NumRows lanes of the wave (lane L holds row L's NumScales e8m0 K-scales = one int64). +// For a 32-row tile this yields lanes 0..15 -> rows 0..15, lanes 16..31 -> rows 16..31, +// which is exactly the layout the hardware selects between via SCL_OPSEL (OpSel). +template +CK_TILE_DEVICE static constexpr auto MakeScaleDistribution() +{ + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<0>>{}); +} + +// Pipeline-style kernel: preload all scales for the 2x2 block in advance, then +// run a 2x2 block-level loop over the 16x16x128 scale16 warp gemm. Each (mIter, +// nIter) sub-tile consumes its own independent per-M/per-N-block K-scales. 
+template +struct WarpGemmScale16BlockLoopKernel +{ + static constexpr int kBlockSize = 32; + static constexpr index_t ScaleBlockK = 16; + static constexpr index_t NumScales = K / ScaleBlockK; + static constexpr index_t MPerWarp = 16; + static constexpr index_t NPerWarp = 16; + static constexpr index_t MIter = M / MPerWarp; + static constexpr index_t NIter = N / NPerWarp; + + __device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const + { + const auto a_view = + make_naive_tensor_view(static_cast(A), + make_tuple(M, K), + make_tuple(K, number<1>{}), + number{}, + number<1>{}); + const auto b_view = + make_naive_tensor_view(static_cast(B), + make_tuple(N, K), + make_tuple(K, number<1>{}), + number{}, + number<1>{}); + const auto c_view = + make_naive_tensor_view(static_cast(C), + make_tuple(M, N), + make_tuple(N, number<1>{}), + number{}, + number<1>{}); + const auto sa_view = + make_naive_tensor_view(static_cast(ScaleA), + make_tuple(M, NumScales), + make_tuple(NumScales, number<1>{}), + number{}, + number<1>{}); + const auto sb_view = + make_naive_tensor_view(static_cast(ScaleB), + make_tuple(N, NumScales), + make_tuple(NumScales, number<1>{}), + number{}, + number<1>{}); + + using WarpGemm = WarpGemmDispatcher< + AType, // ADataType: A element type (fp8_t / bf8_t) + BType, // BDataType: B element type (fp8_t / bf8_t) + float, // AccDataType: accumulator type (F32) + MPerWarp, // MPerWave: warp-tile M (16 - native op size) + NPerWarp, // NPerWave: warp-tile N (16 - native op size) + K, // KPerWave: warp-tile K (128) + TransposeC, // TransposeC: use transposed-C distribution + false, // SwizzleA: A LDS swizzle layout (off) + false, // UseStructuredSparsity: 2:4 sparsity (off) + WGAttrNumAccessEnum::Default, // AttrNumAccessA: A num-access attribute + WGAttrNumAccessEnum::Default, // AttrNumAccessB: B num-access attribute + true>; // IsScale16: select the scale16 WMMA variant + + constexpr auto a_dstr = typename WarpGemm::AWarpDstr{}; + 
constexpr auto b_dstr = typename WarpGemm::BWarpDstr{}; + constexpr auto c_dstr = typename WarpGemm::CWarpDstr{}; + constexpr auto scale_dstr = MakeScaleDistribution(); + + // ---- Preload phase ---- + // A/B sub-tiles for the 2x2 block loop. Scales are staged ONCE across all 32 + // lanes (not per-block arrays): a single int64 A-scale and B-scale carry every + // row's/col's K-scales, and the per-block selection is done in hardware by + // OpSelA/OpSelB (SCL_OPSEL) -- lanes 0..15 for block 0, lanes 16..31 for block 1. + statically_indexed_array{}, number{}), + make_multi_index(0, 0), + a_dstr))), + MIter> + a_tiles; + static_for<0, MIter, 1>{}([&](auto mIter) { + auto a_win = make_tile_window(a_view, + make_tuple(number{}, number{}), + make_multi_index(mIter.value * MPerWarp, 0), + a_dstr); + a_tiles(mIter) = load_tile(a_win); + }); + + statically_indexed_array{}, number{}), + make_multi_index(0, 0), + b_dstr))), + NIter> + b_tiles; + static_for<0, NIter, 1>{}([&](auto nIter) { + auto b_win = make_tile_window(b_view, + make_tuple(number{}, number{}), + make_multi_index(nIter.value * NPerWarp, 0), + b_dstr); + b_tiles(nIter) = load_tile(b_win); + }); + + // Single 64-bit scale scalars: all M rows / N cols staged across the 32 lanes. + auto sa_win = make_tile_window(sa_view, + make_tuple(number{}, number{}), + make_multi_index(0, 0), + scale_dstr); + const int64_t scale_a = + bit_cast(load_tile(sa_win) + .get_thread_buffer() + .template get_as>()[number<0>{}]); + auto sb_win = make_tile_window(sb_view, + make_tuple(number{}, number{}), + make_multi_index(0, 0), + scale_dstr); + const int64_t scale_b = + bit_cast(load_tile(sb_win) + .get_thread_buffer() + .template get_as>()[number<0>{}]); + + // ---- Compute phase ---- + // 2x2 block-level loop. The same int64 scale scalars are passed to every call; + // the correct per-block K-scales are fetched in hardware via OpSelA/OpSelB, which + // select lanes 0..15 (block 0) vs lanes 16..31 (block 1) of the scale register. 
+ static_for<0, MIter, 1>{}([&](auto mIter) { + static_for<0, NIter, 1>{}([&](auto nIter) { + auto c_win = make_tile_window( + c_view, + make_tuple(number{}, number{}), + make_multi_index(mIter.value * MPerWarp, nIter.value * NPerWarp), + c_dstr); + + auto c_tile = + WarpGemm{}.template operator(), OpSelB>( + a_tiles(mIter), b_tiles(nIter), scale_a, scale_b); + store_tile(c_win, c_tile); + }); + }); + } +}; + +template +struct WGDispCase +{ + using AType = A; + using BType = B; + using AccType = Acc; + static constexpr index_t MPerWave = M; + static constexpr index_t NPerWave = N; + static constexpr index_t KPerWave = K; +}; + +template +static void RunTest(const HostTensor& A, + const HostTensor& B, + const HostTensor& ScaleA, + const HostTensor& ScaleB, + HostTensor& C) +{ + DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB); + dim3 grid(1), block{32}; + + using K = WarpGemmScale16BlockLoopKernel; + + (void)launch_kernel(stream_config{nullptr, true, 0, 0, 1}, + make_kernel(K{}, + grid, + block, + 0, + Ad.GetDeviceBuffer(), + Bd.GetDeviceBuffer(), + Cd.GetDeviceBuffer(), + SAd.GetDeviceBuffer(), + SBd.GetDeviceBuffer())); + + Cd.FromDevice(C.mData.data()); +} + +using WGDispatcherTypesList = ::testing::Types, + WGDispCase, + WGDispCase, + WGDispCase>; + +template +class WGScale16Test : public ::testing::Test +{ +}; + +TYPED_TEST_SUITE(WGScale16Test, WGDispatcherTypesList); + +TYPED_TEST(WGScale16Test, Scale16_32x32x128_UniformScale) +{ + using Case = TypeParam; + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = Case::NPerWave; + constexpr index_t K = Case::KPerWave; + constexpr index_t NumScales = K / 16; + + HostTensor A({M, K}); + HostTensor B({N, K}); + HostTensor C({M, N}); + HostTensor sA({M, NumScales}); + HostTensor sB({N, NumScales}); + + FillUniformDistribution{-5.f, 5.f}(A); + FillUniformDistribution{-5.f, 5.f}(B); + 
C.SetZero(); + FillConstant{e8m0_t{2.f}}(sA); + FillConstant{e8m0_t{4.f}}(sB); + + RunTest(A, B, sA, sB, C); + + HostTensor C_ref({M, N}); + C_ref.SetZero(); + reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + + EXPECT_TRUE(check_err(C, C_ref, "Scale16 32x32x128 uniform scale error.")); +} + +TYPED_TEST(WGScale16Test, Scale16_32x32x128_RandomDataRandomScales) +{ + using Case = TypeParam; + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = Case::NPerWave; + constexpr index_t K = Case::KPerWave; + constexpr index_t NumScales = K / 16; + + HostTensor A({M, K}); + HostTensor B({N, K}); + HostTensor C({M, N}); + HostTensor sA({M, NumScales}); + HostTensor sB({N, NumScales}); + + FillUniformDistribution{-5.f, 5.f, 42}(A); + FillUniformDistribution{-5.f, 5.f, 137}(B); + C.SetZero(); + + { + constexpr int bias = ck_tile::numeric_traits::bias; + std::mt19937 gen(9999); + std::uniform_int_distribution dist(bias - 4, bias + 2); + for(auto& s : sA.mData) + s = e8m0_t(static_cast(dist(gen))); + for(auto& s : sB.mData) + s = e8m0_t(static_cast(dist(gen))); + } + + RunTest(A, B, sA, sB, C); + + HostTensor C_ref({M, N}); + C_ref.SetZero(); + reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + + const float max_acc = *std::max_element(C_ref.mData.begin(), C_ref.mData.end()); + const auto rtol = ck_tile::get_relative_threshold(K); + const auto atol = ck_tile::get_absolute_threshold(max_acc, K); + EXPECT_TRUE( + check_err(C, C_ref, "Scale16 32x32x128 random data + random scales error.", rtol, atol)); +} + +TYPED_TEST(WGScale16Test, Scale16_32x32x128_TransposeC) +{ + using Case = TypeParam; + using AType = typename Case::AType; + using BType = typename Case::BType; + using CType = typename Case::AccType; + + constexpr index_t M = Case::MPerWave; + constexpr index_t N = Case::NPerWave; + constexpr index_t K = 
Case::KPerWave; + constexpr index_t NumScales = K / 16; + + HostTensor A({M, K}); + HostTensor B({N, K}); + HostTensor C({M, N}); + HostTensor sA({M, NumScales}); + HostTensor sB({N, NumScales}); + + FillUniformDistribution{-5.f, 5.f, 77}(A); + FillUniformDistribution{-5.f, 5.f, 88}(B); + C.SetZero(); + + { + constexpr int bias = ck_tile::numeric_traits::bias; + std::mt19937 gen(5555); + std::uniform_int_distribution dist(bias - 4, bias + 2); + for(auto& s : sA.mData) + s = e8m0_t(static_cast(dist(gen))); + for(auto& s : sB.mData) + s = e8m0_t(static_cast(dist(gen))); + } + + RunTest(A, B, sA, sB, C); + + HostTensor C_ref({M, N}); + C_ref.SetZero(); + reference_mx_gemm( + A, B.transpose(), C_ref, sA, sB.transpose()); + + const float max_acc = *std::max_element(C_ref.mData.begin(), C_ref.mData.end()); + const auto rtol = ck_tile::get_relative_threshold(K); + const auto atol = ck_tile::get_absolute_threshold(max_acc, K); + EXPECT_TRUE(check_err(C, C_ref, "Scale16 32x32x128 TransposeC error.", rtol, atol)); +}