[rocm-libraries] ROCm/rocm-libraries#7725 (commit eef7e12)

[GFX1250][CK_TILE] Add scale16 warp gemm unit tests

## Summary
- Add scale16 WMMA intrinsic overloads and int64_t forwarding to warp
gemm layers for gfx1250
- Add comprehensive wave-level unit tests for scale16 warp gemm
(16x16x128 and 32x32x128 tile sizes)
- Test all fp8/bf8 type combinations and TransposeC variants
- Fix WarpGemm wrapper for non-uniform scale16 configurations

Stacked on #7724 (FillUniformScaleDistribution / MX GEMM scale init).
Pipeline enablement follows in the next PR.
This commit is contained in:
Aviral Goel
2026-06-03 22:05:29 +00:00
committed by assistant-librarian[bot]
parent 45a8f96c66
commit e01603bc31
8 changed files with 1019 additions and 89 deletions

View File

@@ -154,6 +154,19 @@ struct CTransposedWarpDstrEncodingTrait
typename Impl::kCTYs2RHsMinor>;
};
namespace detail {
// Detection trait: `type` is `Traits::MXTypeEnableType` when that nested type exists,
// and `void` otherwise.
//
// Fallback case, used when `Traits` exposes no `MXTypeEnableType`.
template <class Traits, class Enable = void>
struct mx_type_enable_or_void
{
    using type = void;
};
// Preferred case: the std::void_t argument is only well-formed when the nested type
// exists, so this specialization drops out of consideration otherwise.
template <class Traits>
struct mx_type_enable_or_void<Traits, std::void_t<typename Traits::MXTypeEnableType>>
{
    using type = typename Traits::MXTypeEnableType;
};
} // namespace detail
template <typename WarpGemmAttributeWmmaImpl_,
bool kTransC = false,
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
@@ -162,18 +175,21 @@ struct WarpGemmAttributeWmma
{
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
// When kTransC is true and A/B types differ, we need an impl with swapped types
using TransposedImpl =
std::conditional_t<kTransC &&
!std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
WarpGemmAttributeWmmaImpl<WmmaTraits<typename Impl::TraitsType::ArchType,
typename Impl::BDataType,
typename Impl::ADataType,
typename Impl::CDataType,
Impl::kM,
Impl::kN,
Impl::kK>>,
Impl>;
// When kTransC is true and A/B types differ, we need an impl with swapped types.
// Propagate MXTypeEnable (e.g., WmmaScale16Tag) so the transposed impl uses the
// same WmmaTraits specialization family.
using TransposedImpl = std::conditional_t<
kTransC && !std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
WarpGemmAttributeWmmaImpl<
WmmaTraits<typename Impl::TraitsType::ArchType,
typename Impl::BDataType,
typename Impl::ADataType,
typename Impl::CDataType,
Impl::kM,
Impl::kN,
Impl::kK,
typename detail::mx_type_enable_or_void<typename Impl::TraitsType>::type>>,
Impl>;
using ADataType = typename Impl::ADataType;
using BDataType = typename Impl::BDataType;

View File

@@ -204,6 +204,13 @@ template <typename AType, typename BType>
using WarpGemmAttributeWmmaImpl_f32_16x16x128_f8f6f4 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, AType, BType, float, 16, 16, 128>>;
// WmmaScale16Tag (declared above) is passed as MXTypeEnable to WmmaTraits to select scale16
// specializations. These override kAK1PerLane and kBK1PerLane to 16 (-> sequence<4,2,16>) and
// take int64_t scales for V_WMMA_SCALE16_F32_16X16X128_F8F6F4, vs the default layout / int32_t.
template <typename AType, typename BType>
using WarpGemmAttributeWmmaImpl_f32_16x16x128_f8f6f4_scale16 = WarpGemmAttributeWmmaImpl<
WmmaTraits<gfx125_t, AType, BType, float, 16, 16, 128, WmmaScale16Tag>>;
template <typename AType, typename BType>
using WarpGemmAttributeWmmaImpl_f32_32x32x128_f8f6f4 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, AType, BType, float, 32, 32, 128>>;

View File

@@ -427,6 +427,82 @@ struct WmmaTraits<gfx125_t, bf8_t, bf8_t, fp16_t, 16, 16, 64>
}
};
// f8f6f4 specialization - GFX125
// Data-format codes for the f8f6f4 WMMA intrinsics. The numeric values are spelled
// out because the enumerators are handed to the intrinsic as plain operands.
enum F8F6F4OpDataTypeEnum
{
    E4M3 = 0x0,
    E5M2 = 0x1,
    E2M3 = 0x2,
    E3M2 = 0x3,
    E2M1 = 0x4,
};
// Traits for MX data types used in f8f6f4 intrinsics.
// The primary template is only declared, never defined, so using an element type that
// has no explicit specialization is a compile error.
template <typename T>
struct MXDataTypeTrait;
// fp8_t: format code E4M3. The operand type is already int32x16_t, so to_wmma_vec
// returns its argument unchanged.
template <>
struct MXDataTypeTrait<fp8_t>
{
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E4M3;
using VecType = int32x16_t;
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; }
};
// bf8_t: format code E5M2. The operand type is already int32x16_t, so to_wmma_vec
// returns its argument unchanged.
template <>
struct MXDataTypeTrait<bf8_t>
{
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E5M2;
using VecType = int32x16_t;
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; }
};
// pk_fp4_t: format code E2M1. The operand arrives as int32x8_t; to_wmma_vec copies its
// eight elements into the low half of an int32x16_t and zero-fills the remaining eight.
template <>
struct MXDataTypeTrait<pk_fp4_t>
{
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M1;
using VecType = int32x8_t;
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x8_t& vec)
{
return int32x16_t{
vec[0], vec[1], vec[2], vec[3], vec[4], vec[5], vec[6], vec[7], 0, 0, 0, 0, 0, 0, 0, 0};
}
};
// pk_fp6x16_t (legacy): 16 fp6 e2m3 values packed into 3 int32 (96 bits).
// At 16x16x128 each lane holds 64 fp6 elements = 4 packs = 12 int32
// (f6x16xN_tt<4, f6_kind::fp6>, whose storage is int32_t data[12]),
// padded with 4 zero int32 elements to fit the 16-wide f8f6f4 wmma input.
template <>
struct MXDataTypeTrait<pk_fp6x16_t>
{
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M3;
using VecType = f6x16xN_tt<4, f6_kind::fp6>;
// Copies data[0..11] into elements 0..11 of the result; elements 12..15 are zero.
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const f6x16xN_tt<4, f6_kind::fp6>& vec)
{
return int32x16_t{vec.data[0],
vec.data[1],
vec.data[2],
vec.data[3],
vec.data[4],
vec.data[5],
vec.data[6],
vec.data[7],
vec.data[8],
vec.data[9],
vec.data[10],
vec.data[11],
0,
0,
0,
0};
}
};
template <>
struct WmmaTraits<gfx125_t, bf8_t, bf8_t, float, 16, 16, 128>
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float, 128>
@@ -508,6 +584,232 @@ struct WmmaTraits<gfx125_t, bf8_t, fp8_t, float, 16, 16, 128>
}
};
// scale16 specializations: fp8xfp8, bf8xbf8, fp8xbf8, bf8xfp8
// Override kAK1PerLane/kBK1PerLane to 16 for scale16 register layout -> sequence<4,2,16>
//
// fp8 (A) x fp8 (B) -> float, 16x16x128, scale16 variant selected by WmmaScale16Tag.
template <>
struct WmmaTraits<gfx125_t, fp8_t, fp8_t, float, 16, 16, 128, WmmaScale16Tag>
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float, 128>
{
using ArchType = gfx125_t;
// Nested tag alias; WarpGemmAttributeWmma reads it (through
// detail::mx_type_enable_or_void) when it rebuilds the traits with A/B swapped.
using MXTypeEnableType = WmmaScale16Tag;
// K1 per lane is fixed at 16 for both operands and K0 is derived from it.
// kK and kABKLane are not declared in this struct -- presumably inherited from
// WmmaTraitsBase; confirm.
static constexpr index_t kAK1PerLane = 16;
static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane);
static constexpr index_t kBK1PerLane = 16;
static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane);
// Scale-less overload. `sizeof...(Params) < 0` can never hold, so instantiating this
// overload always trips the static_assert: callers must use the scaled overload.
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&)
{
static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments");
return CVecType{0};
}
// Scaled overload.
//   a_vec / b_vec     : A and B operands, bit-cast to the MXDataTypeTrait vector type.
//   a_scale / b_scale : 64-bit scale operands forwarded unchanged to the builtin.
//   c_vec             : C operand, bit-cast to fp32x8_t.
//   Params...         : parsed by WarpGemmParamsParser for op_sel_a/scale_a/op_sel_b/scale_b.
// Returns the builtin's result when compiling for __gfx125__, CVecType{0} otherwise.
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
const int64_t& a_scale,
const BVecType& b_vec,
const int64_t& b_scale,
const CVecType& c_vec)
{
#ifdef __gfx125__
using P = WarpGemmParamsParser<Params...>;
using ATraits = MXDataTypeTrait<fp8_t>;
using BTraits = MXDataTypeTrait<fp8_t>;
// NOTE(review): the meaning of the literal 0 operands is not visible here --
// confirm against the builtin's signature.
return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4(
ATraits::OpDataType,
ATraits::to_wmma_vec(bit_cast<typename ATraits::VecType>(a_vec)),
BTraits::OpDataType,
BTraits::to_wmma_vec(bit_cast<typename BTraits::VecType>(b_vec)),
0,
bit_cast<fp32x8_t>(c_vec),
P::op_sel_a,
P::scale_a,
a_scale,
P::op_sel_b,
P::scale_b,
b_scale,
0,
0);
#else
// Not compiling for gfx125: consume the arguments and return CVecType{0}.
ck_tile::ignore = a_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_vec;
ck_tile::ignore = b_scale;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
// bf8 (A) x bf8 (B) -> float, 16x16x128, scale16 variant selected by WmmaScale16Tag.
template <>
struct WmmaTraits<gfx125_t, bf8_t, bf8_t, float, 16, 16, 128, WmmaScale16Tag>
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float, 128>
{
using ArchType = gfx125_t;
// Nested tag alias; WarpGemmAttributeWmma reads it (through
// detail::mx_type_enable_or_void) when it rebuilds the traits with A/B swapped.
using MXTypeEnableType = WmmaScale16Tag;
// K1 per lane is fixed at 16 for both operands and K0 is derived from it.
// kK and kABKLane are not declared in this struct -- presumably inherited from
// WmmaTraitsBase; confirm.
static constexpr index_t kAK1PerLane = 16;
static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane);
static constexpr index_t kBK1PerLane = 16;
static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane);
// Scale-less overload. `sizeof...(Params) < 0` can never hold, so instantiating this
// overload always trips the static_assert: callers must use the scaled overload.
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&)
{
static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments");
return CVecType{0};
}
// Scaled overload.
//   a_vec / b_vec     : A and B operands, bit-cast to the MXDataTypeTrait vector type.
//   a_scale / b_scale : 64-bit scale operands forwarded unchanged to the builtin.
//   c_vec             : C operand, bit-cast to fp32x8_t.
//   Params...         : parsed by WarpGemmParamsParser for op_sel_a/scale_a/op_sel_b/scale_b.
// Returns the builtin's result when compiling for __gfx125__, CVecType{0} otherwise.
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
const int64_t& a_scale,
const BVecType& b_vec,
const int64_t& b_scale,
const CVecType& c_vec)
{
#ifdef __gfx125__
using P = WarpGemmParamsParser<Params...>;
using ATraits = MXDataTypeTrait<bf8_t>;
using BTraits = MXDataTypeTrait<bf8_t>;
// NOTE(review): the meaning of the literal 0 operands is not visible here --
// confirm against the builtin's signature.
return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4(
ATraits::OpDataType,
ATraits::to_wmma_vec(bit_cast<typename ATraits::VecType>(a_vec)),
BTraits::OpDataType,
BTraits::to_wmma_vec(bit_cast<typename BTraits::VecType>(b_vec)),
0,
bit_cast<fp32x8_t>(c_vec),
P::op_sel_a,
P::scale_a,
a_scale,
P::op_sel_b,
P::scale_b,
b_scale,
0,
0);
#else
// Not compiling for gfx125: consume the arguments and return CVecType{0}.
ck_tile::ignore = a_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_vec;
ck_tile::ignore = b_scale;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
// fp8 (A) x bf8 (B) -> float, 16x16x128, scale16 variant selected by WmmaScale16Tag.
template <>
struct WmmaTraits<gfx125_t, fp8_t, bf8_t, float, 16, 16, 128, WmmaScale16Tag>
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float, 128>
{
using ArchType = gfx125_t;
// Nested tag alias; WarpGemmAttributeWmma reads it (through
// detail::mx_type_enable_or_void) when it rebuilds the traits with A/B swapped.
using MXTypeEnableType = WmmaScale16Tag;
// K1 per lane is fixed at 16 for both operands and K0 is derived from it.
// kK and kABKLane are not declared in this struct -- presumably inherited from
// WmmaTraitsBase; confirm.
static constexpr index_t kAK1PerLane = 16;
static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane);
static constexpr index_t kBK1PerLane = 16;
static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane);
// Scale-less overload. `sizeof...(Params) < 0` can never hold, so instantiating this
// overload always trips the static_assert: callers must use the scaled overload.
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&)
{
static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments");
return CVecType{0};
}
// Scaled overload.
//   a_vec / b_vec     : A and B operands, bit-cast to the MXDataTypeTrait vector type.
//   a_scale / b_scale : 64-bit scale operands forwarded unchanged to the builtin.
//   c_vec             : C operand, bit-cast to fp32x8_t.
//   Params...         : parsed by WarpGemmParamsParser for op_sel_a/scale_a/op_sel_b/scale_b.
// Returns the builtin's result when compiling for __gfx125__, CVecType{0} otherwise.
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
const int64_t& a_scale,
const BVecType& b_vec,
const int64_t& b_scale,
const CVecType& c_vec)
{
#ifdef __gfx125__
using P = WarpGemmParamsParser<Params...>;
// A is passed with the fp8 format code, B with the bf8 format code.
using ATraits = MXDataTypeTrait<fp8_t>;
using BTraits = MXDataTypeTrait<bf8_t>;
// NOTE(review): the meaning of the literal 0 operands is not visible here --
// confirm against the builtin's signature.
return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4(
ATraits::OpDataType,
ATraits::to_wmma_vec(bit_cast<typename ATraits::VecType>(a_vec)),
BTraits::OpDataType,
BTraits::to_wmma_vec(bit_cast<typename BTraits::VecType>(b_vec)),
0,
bit_cast<fp32x8_t>(c_vec),
P::op_sel_a,
P::scale_a,
a_scale,
P::op_sel_b,
P::scale_b,
b_scale,
0,
0);
#else
// Not compiling for gfx125: consume the arguments and return CVecType{0}.
ck_tile::ignore = a_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_vec;
ck_tile::ignore = b_scale;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
// bf8 (A) x fp8 (B) -> float, 16x16x128, scale16 variant selected by WmmaScale16Tag.
template <>
struct WmmaTraits<gfx125_t, bf8_t, fp8_t, float, 16, 16, 128, WmmaScale16Tag>
: WmmaTraitsBase<gfx12_t, bf8_t, fp8_t, float, 128>
{
using ArchType = gfx125_t;
// Nested tag alias; WarpGemmAttributeWmma reads it (through
// detail::mx_type_enable_or_void) when it rebuilds the traits with A/B swapped.
using MXTypeEnableType = WmmaScale16Tag;
// K1 per lane is fixed at 16 for both operands and K0 is derived from it.
// kK and kABKLane are not declared in this struct -- presumably inherited from
// WmmaTraitsBase; confirm.
static constexpr index_t kAK1PerLane = 16;
static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane);
static constexpr index_t kBK1PerLane = 16;
static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane);
// Scale-less overload. `sizeof...(Params) < 0` can never hold, so instantiating this
// overload always trips the static_assert: callers must use the scaled overload.
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&)
{
static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments");
return CVecType{0};
}
// Scaled overload.
//   a_vec / b_vec     : A and B operands, bit-cast to the MXDataTypeTrait vector type.
//   a_scale / b_scale : 64-bit scale operands forwarded unchanged to the builtin.
//   c_vec             : C operand, bit-cast to fp32x8_t.
//   Params...         : parsed by WarpGemmParamsParser for op_sel_a/scale_a/op_sel_b/scale_b.
// Returns the builtin's result when compiling for __gfx125__, CVecType{0} otherwise.
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
const int64_t& a_scale,
const BVecType& b_vec,
const int64_t& b_scale,
const CVecType& c_vec)
{
#ifdef __gfx125__
using P = WarpGemmParamsParser<Params...>;
// A is passed with the bf8 format code, B with the fp8 format code.
using ATraits = MXDataTypeTrait<bf8_t>;
using BTraits = MXDataTypeTrait<fp8_t>;
// NOTE(review): the meaning of the literal 0 operands is not visible here --
// confirm against the builtin's signature.
return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4(
ATraits::OpDataType,
ATraits::to_wmma_vec(bit_cast<typename ATraits::VecType>(a_vec)),
BTraits::OpDataType,
BTraits::to_wmma_vec(bit_cast<typename BTraits::VecType>(b_vec)),
0,
bit_cast<fp32x8_t>(c_vec),
P::op_sel_a,
P::scale_a,
a_scale,
P::op_sel_b,
P::scale_b,
b_scale,
0,
0);
#else
// Not compiling for gfx125: consume the arguments and return CVecType{0}.
ck_tile::ignore = a_vec;
ck_tile::ignore = a_scale;
ck_tile::ignore = b_vec;
ck_tile::ignore = b_scale;
ck_tile::ignore = c_vec;
return CVecType{0};
#endif
}
};
// 32x16x128 f4 specialization - GFX125
template <>
struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 16, 128>
@@ -648,82 +950,6 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128, WmmaScale16T
{
};
// f8f6f4 specialization - GFX125
enum F8F6F4OpDataTypeEnum
{
E4M3, // 0x0
E5M2, // 0x1
E2M3, // 0x2
E3M2, // 0x3
E2M1, // 0x4
};
// Traits for MX data types used in f8f6f4 intrinsics
template <typename T>
struct MXDataTypeTrait;
template <>
struct MXDataTypeTrait<fp8_t>
{
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E4M3;
using VecType = int32x16_t;
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; }
};
template <>
struct MXDataTypeTrait<bf8_t>
{
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E5M2;
using VecType = int32x16_t;
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; }
};
template <>
struct MXDataTypeTrait<pk_fp4_t>
{
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M1;
using VecType = int32x8_t;
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x8_t& vec)
{
return int32x16_t{
vec[0], vec[1], vec[2], vec[3], vec[4], vec[5], vec[6], vec[7], 0, 0, 0, 0, 0, 0, 0, 0};
}
};
// pk_fp6x16_t (legacy): 16 fp6 e2m3 values packed into 3 int32 (96 bits).
// At 16x16x128 each lane holds 64 fp6 elements = 4 packs = 12 int32
// (f6x16xN_tt<4, f6_kind::fp6>, whose storage is int32_t data[12]),
// padded with 4 zero lanes to fit the 16-wide f8f6f4 wmma input.
template <>
struct MXDataTypeTrait<pk_fp6x16_t>
{
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M3;
using VecType = f6x16xN_tt<4, f6_kind::fp6>;
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const f6x16xN_tt<4, f6_kind::fp6>& vec)
{
return int32x16_t{vec.data[0],
vec.data[1],
vec.data[2],
vec.data[3],
vec.data[4],
vec.data[5],
vec.data[6],
vec.data[7],
vec.data[8],
vec.data[9],
vec.data[10],
vec.data[11],
0,
0,
0,
0};
}
};
// Unified WmmaTraits for f8f6f4 combinations
template <typename AType, typename BType>
struct WmmaTraits<

View File

@@ -240,6 +240,9 @@ template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<b
// F8F6F4 Mixed precision cases
template<typename A, typename B, bool TransposeC, WGAttrNumAccessEnum AttrNumAccessA, WGAttrNumAccessEnum AttrNumAccessB> struct Dispatcher<A, B, float, 16, 16, 128, TransposeC, false, false, AttrNumAccessA, AttrNumAccessB> : WmmaTag { using Type = WarpGemmWmma_f32_16x16x128_f8f6f4<A, B, TransposeC, AttrNumAccessA, AttrNumAccessB>; };
// F8F6F4 Scale16 (IsScale16=true): same key as the F8F6F4 mixed-precision case plus a
// trailing `true` argument; resolves to the scale16 warp gemm.
template<typename A, typename B, bool TransposeC, WGAttrNumAccessEnum AttrNumAccessA, WGAttrNumAccessEnum AttrNumAccessB> struct Dispatcher<A, B, float, 16, 16, 128, TransposeC, false, false, AttrNumAccessA, AttrNumAccessB, true> : WmmaTag { using Type = WarpGemmWmma_f32_16x16x128_f8f6f4_scale16<A, B, TransposeC, AttrNumAccessA, AttrNumAccessB>; };
#else
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; };
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };

View File

@@ -216,6 +216,17 @@ using WarpGemmWmma_f32_16x16x128_f8f6f4 =
AttrNumAccessA,
AttrNumAccessB>>;
// Scale16 warp gemm for 16x16x128 f8f6f4 operands with float accumulation: wraps the
// scale16 attribute impl for <AType, BType> in WarpGemmAttributeWmma and WarpGemmImpl.
// kTransC and the two num-access attributes are forwarded to WarpGemmAttributeWmma.
template <typename AType,
typename BType,
bool kTransC,
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Default,
WGAttrNumAccessEnum AttrNumAccessB = WGAttrNumAccessEnum::Default>
using WarpGemmWmma_f32_16x16x128_f8f6f4_scale16 = WarpGemmImpl<
WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x128_f8f6f4_scale16<AType, BType>,
kTransC,
AttrNumAccessA,
AttrNumAccessB>>;
template <typename AType,
typename BType,
bool kTransC,

View File

@@ -4,8 +4,14 @@
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp)
endif()
# Scale16 warp gemm tests for V_WMMA_SCALE16_F32_16X16X128_F8F6F4.
# Each test executable covers 4 type combos (fp8xfp8, bf8xbf8, fp8xbf8, bf8xfp8) x 3 cases
# (uniform scale, random scale, TransposeC) = 12 tests.
# 16x16: single warp gemm call. 32x32: pipeline-style 2x2 block loop over the 16x16 warp gemm;
# scales are staged once and the per-block scales are selected via OpSelA/OpSelB.
if(GPU_TARGETS MATCHES "gfx125")
add_gtest_executable(test_ck_tile_wg_16x16x128_fp8_scale16 test_f32_16x16x128_fp8_scale16.cpp)
add_gtest_executable(test_ck_tile_wg_32x32x128_fp8_scale16 test_f32_32x32x128_fp8_scale16.cpp)
add_gtest_executable(test_ck_tile_wmma_bf16_16x16x32_gfx1250 test_wmma_bf16_16x16x32_gfx1250.cpp)
add_gtest_executable(test_ck_tile_wg_32x16x128_fp4 test_f32_32x16x128_fp4.cpp)
endif()

View File

@@ -0,0 +1,299 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
using namespace ck_tile;
// Builds the static tile distribution used to load a 16 x NumScales tile of e8m0 scales.
// The encoding lists, in order: sequence<2>, the per-dimension lengths (16 and NumScales),
// the P mapping (0,1)/(0,0) and the Y mapping 2/0.
// NOTE(review): presumably sequence<2> is a replication factor, so that lanes L and L+16
// of the 32-lane wave hold the same scale row, and each lane owns all NumScales values
// of its row -- confirm against tile_distribution_encoding.
template <index_t NumScales>
CK_TILE_DEVICE static constexpr auto MakeScaleDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<2>,
tuple<sequence<16>, sequence<NumScales>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 0>>,
sequence<2>,
sequence<0>>{});
}
// Scale16 kernel using the WarpGemm wrapper (Layer 4).
// WarpGemmWmma_f32_16x16x128_f8f6f4_scale16 produces AWarpDstr with sequence<4,2,16>,
// matching the hardware's 1-scale-byte-per-16-K-elements register layout.
//
// Computes one M x N warp gemm: loads A (M x K), B (N x K) and both scale tiles from
// global memory, calls the dispatched scale16 warp gemm once, and stores C (M x N).
template <typename AType,
typename BType,
typename CType,
index_t M,
index_t N,
index_t K,
bool TransposeC>
struct WarpGemmScale16Kernel
{
// The host launches this kernel with a block of 32 threads.
static constexpr int kBlockSize = 32;
// One e8m0 scale per 16 K-elements, so NumScales scales per row (8 for K = 128).
static constexpr index_t ScaleBlockK = 16;
static constexpr index_t NumScales = K / ScaleBlockK;
// A, B, C, ScaleA, ScaleB are device pointers, cast below to AType*, BType*, CType*
// and e8m0_t* respectively. All views are row-major (stride of the last dim is 1).
__device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const
{
const auto a_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<AType*>(A),
make_tuple(M, K),
make_tuple(K, number<1>{}),
number<K>{},
number<1>{});
// B is stored as N x K (one row per output column).
const auto b_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<BType*>(B),
make_tuple(N, K),
make_tuple(K, number<1>{}),
number<K>{},
number<1>{});
const auto c_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<CType*>(C),
make_tuple(M, N),
make_tuple(N, number<1>{}),
number<N>{},
number<1>{});
// Scale views: M x NumScales for A and N x NumScales for B.
const auto sa_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<e8m0_t*>(ScaleA),
make_tuple(M, NumScales),
make_tuple(NumScales, number<1>{}),
number<NumScales>{},
number<1>{});
const auto sb_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<e8m0_t*>(ScaleB),
make_tuple(N, NumScales),
make_tuple(NumScales, number<1>{}),
number<NumScales>{},
number<1>{});
// The trailing `true` selects the scale16 dispatcher entry; the two `false`
// arguments and the Default num-access attributes leave the other options off.
using WarpGemm = WarpGemmDispatcher<AType,
BType,
float,
M,
N,
K,
TransposeC,
false,
false,
WGAttrNumAccessEnum::Default,
WGAttrNumAccessEnum::Default,
true>;
// Operand and result distributions come from the warp gemm itself; the scale
// distribution is the test-local one.
constexpr auto a_dstr = typename WarpGemm::AWarpDstr{};
constexpr auto b_dstr = typename WarpGemm::BWarpDstr{};
constexpr auto c_dstr = typename WarpGemm::CWarpDstr{};
constexpr auto scale_dstr = MakeScaleDistribution<NumScales>();
// Every window covers its whole tensor, anchored at the origin.
auto a_win = make_tile_window(
a_view, make_tuple(number<M>{}, number<K>{}), make_multi_index(0, 0), a_dstr);
auto b_win = make_tile_window(
b_view, make_tuple(number<N>{}, number<K>{}), make_multi_index(0, 0), b_dstr);
auto c_win = make_tile_window(
c_view, make_tuple(number<M>{}, number<N>{}), make_multi_index(0, 0), c_dstr);
auto sa_win = make_tile_window(sa_view,
make_tuple(number<M>{}, number<NumScales>{}),
make_multi_index(0, 0),
scale_dstr);
auto sb_win = make_tile_window(sb_view,
make_tuple(number<N>{}, number<NumScales>{}),
make_multi_index(0, 0),
scale_dstr);
auto a_tile = load_tile(a_win);
auto b_tile = load_tile(b_win);
auto sa_tile = load_tile(sa_win);
auto sb_tile = load_tile(sb_win);
// Reinterpret this thread's NumScales e8m0 values as a single int64_t scale operand.
int64_t scale_a =
bit_cast<int64_t>(sa_tile.get_thread_buffer()
.template get_as<ext_vector_t<e8m0_t, NumScales>>()[number<0>{}]);
int64_t scale_b =
bit_cast<int64_t>(sb_tile.get_thread_buffer()
.template get_as<ext_vector_t<e8m0_t, NumScales>>()[number<0>{}]);
// Single warp gemm call; the result tile is written through the C window.
auto c_tile = WarpGemm{}(a_tile, b_tile, scale_a, scale_b);
store_tile(c_win, c_tile);
}
};
// Compile-time description of one typed-test case: the A/B/accumulator element
// types and the per-wave M/N/K extents.
template <typename A, typename B, typename Acc, index_t M, index_t N, index_t K>
struct WGDispCase
{
    // Per-wave tile extents.
    static constexpr index_t MPerWave = M;
    static constexpr index_t NPerWave = N;
    static constexpr index_t KPerWave = K;

    // Element types.
    using AType   = A;
    using BType   = B;
    using AccType = Acc;
};
template <typename Case, bool TransposeC>
static void RunTest(const HostTensor<typename Case::AType>& A,
const HostTensor<typename Case::BType>& B,
const HostTensor<e8m0_t>& ScaleA,
const HostTensor<e8m0_t>& ScaleB,
HostTensor<typename Case::AccType>& C)
{
DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB);
dim3 grid(1), block{32};
using K = WarpGemmScale16Kernel<typename Case::AType,
typename Case::BType,
typename Case::AccType,
Case::MPerWave,
Case::NPerWave,
Case::KPerWave,
TransposeC>;
(void)launch_kernel(stream_config{nullptr, true, 0, 0, 1},
make_kernel(K{},
grid,
block,
0,
Ad.GetDeviceBuffer(),
Bd.GetDeviceBuffer(),
Cd.GetDeviceBuffer(),
SAd.GetDeviceBuffer(),
SBd.GetDeviceBuffer()));
Cd.FromDevice(C.mData.data());
}
// The four fp8/bf8 operand combinations, each with float accumulation at 16x16x128.
using WGDispatcherTypesList = ::testing::Types<WGDispCase<fp8_t, fp8_t, float, 16, 16, 128>,
WGDispCase<bf8_t, bf8_t, float, 16, 16, 128>,
WGDispCase<fp8_t, bf8_t, float, 16, 16, 128>,
WGDispCase<bf8_t, fp8_t, float, 16, 16, 128>>;
// Empty typed-test fixture; it is instantiated once per entry of WGDispatcherTypesList.
template <typename Case>
class WGScale16Test : public ::testing::Test
{
};
TYPED_TEST_SUITE(WGScale16Test, WGDispatcherTypesList);
// Constant scales: every A scale is built from 2.f and every B scale from 4.f.
// The device result is compared against reference_mx_gemm with check_err's default
// tolerances; TransposeC is off.
TYPED_TEST(WGScale16Test, Scale16_16x16x128_UniformScale)
{
    using ADataType   = typename TypeParam::AType;
    using BDataType   = typename TypeParam::BType;
    using AccDataType = typename TypeParam::AccType;

    constexpr index_t kM         = TypeParam::MPerWave;
    constexpr index_t kN         = TypeParam::NPerWave;
    constexpr index_t kK         = TypeParam::KPerWave;
    constexpr index_t kNumScales = kK / 16;

    HostTensor<ADataType> a_host({kM, kK});
    HostTensor<BDataType> b_host({kN, kK});
    HostTensor<AccDataType> c_host({kM, kN});
    HostTensor<e8m0_t> a_scales({kM, kNumScales});
    HostTensor<e8m0_t> b_scales({kN, kNumScales});

    FillUniformDistribution<ADataType>{-5.f, 5.f}(a_host);
    FillUniformDistribution<BDataType>{-5.f, 5.f}(b_host);
    c_host.SetZero();
    FillConstant<e8m0_t>{e8m0_t{2.f}}(a_scales);
    FillConstant<e8m0_t>{e8m0_t{4.f}}(b_scales);

    RunTest<TypeParam, false>(a_host, b_host, a_scales, b_scales, c_host);

    // The reference takes B and its scales transposed.
    HostTensor<AccDataType> c_expected({kM, kN});
    c_expected.SetZero();
    reference_mx_gemm<ADataType, BDataType, e8m0_t, e8m0_t, AccDataType, AccDataType>(
        a_host, b_host.transpose(), c_expected, a_scales, b_scales.transpose());

    EXPECT_TRUE(check_err(c_host, c_expected, "Scale16 uniform scale error."));
}
// Seeded random operands with random per-row scales, TransposeC off. Tolerances are
// derived from K and the largest reference value.
TYPED_TEST(WGScale16Test, Scale16_16x16x128_RandomDataRandomScales)
{
    using ADataType   = typename TypeParam::AType;
    using BDataType   = typename TypeParam::BType;
    using AccDataType = typename TypeParam::AccType;

    constexpr index_t kM         = TypeParam::MPerWave;
    constexpr index_t kN         = TypeParam::NPerWave;
    constexpr index_t kK         = TypeParam::KPerWave;
    constexpr index_t kNumScales = kK / 16;

    HostTensor<ADataType> a_host({kM, kK});
    HostTensor<BDataType> b_host({kN, kK});
    HostTensor<AccDataType> c_host({kM, kN});
    HostTensor<e8m0_t> a_scales({kM, kNumScales});
    HostTensor<e8m0_t> b_scales({kN, kNumScales});

    FillUniformDistribution<ADataType>{-5.f, 5.f, 42}(a_host);
    FillUniformDistribution<BDataType>{-5.f, 5.f, 137}(b_host);
    c_host.SetZero();

    {
        // Raw e8m0 codes are drawn from [bias - 4, bias + 2]; A's scales are drawn
        // before B's from the same generator.
        constexpr int kE8m0Bias = ck_tile::numeric_traits<e8m0_t>::bias;
        std::mt19937 rng(9999);
        std::uniform_int_distribution<int> code_dist(kE8m0Bias - 4, kE8m0Bias + 2);
        const auto draw_scales = [&](HostTensor<e8m0_t>& scales) {
            for(auto& value : scales.mData)
                value = e8m0_t(static_cast<typename e8m0_t::type>(code_dist(rng)));
        };
        draw_scales(a_scales);
        draw_scales(b_scales);
    }

    RunTest<TypeParam, false>(a_host, b_host, a_scales, b_scales, c_host);

    // The reference takes B and its scales transposed.
    HostTensor<AccDataType> c_expected({kM, kN});
    c_expected.SetZero();
    reference_mx_gemm<ADataType, BDataType, e8m0_t, e8m0_t, AccDataType, AccDataType>(
        a_host, b_host.transpose(), c_expected, a_scales, b_scales.transpose());

    const float ref_max = *std::max_element(c_expected.mData.begin(), c_expected.mData.end());
    const auto rel_tol =
        ck_tile::get_relative_threshold<ADataType, AccDataType, AccDataType>(kK);
    const auto abs_tol =
        ck_tile::get_absolute_threshold<ADataType, AccDataType, AccDataType>(ref_max, kK);

    EXPECT_TRUE(check_err(
        c_host, c_expected, "Scale16 random data + random scales error.", rel_tol, abs_tol));
}
// Same flow as the random-scale case but with TransposeC enabled and different seeds.
// The expected result is still the untransposed M x N reference.
TYPED_TEST(WGScale16Test, Scale16_16x16x128_TransposeC)
{
    using ADataType   = typename TypeParam::AType;
    using BDataType   = typename TypeParam::BType;
    using AccDataType = typename TypeParam::AccType;

    constexpr index_t kM         = TypeParam::MPerWave;
    constexpr index_t kN         = TypeParam::NPerWave;
    constexpr index_t kK         = TypeParam::KPerWave;
    constexpr index_t kNumScales = kK / 16;

    HostTensor<ADataType> a_host({kM, kK});
    HostTensor<BDataType> b_host({kN, kK});
    HostTensor<AccDataType> c_host({kM, kN});
    HostTensor<e8m0_t> a_scales({kM, kNumScales});
    HostTensor<e8m0_t> b_scales({kN, kNumScales});

    FillUniformDistribution<ADataType>{-5.f, 5.f, 77}(a_host);
    FillUniformDistribution<BDataType>{-5.f, 5.f, 88}(b_host);
    c_host.SetZero();

    {
        // Raw e8m0 codes are drawn from [bias - 4, bias + 2]; A's scales are drawn
        // before B's from the same generator.
        constexpr int kE8m0Bias = ck_tile::numeric_traits<e8m0_t>::bias;
        std::mt19937 rng(5555);
        std::uniform_int_distribution<int> code_dist(kE8m0Bias - 4, kE8m0Bias + 2);
        const auto draw_scales = [&](HostTensor<e8m0_t>& scales) {
            for(auto& value : scales.mData)
                value = e8m0_t(static_cast<typename e8m0_t::type>(code_dist(rng)));
        };
        draw_scales(a_scales);
        draw_scales(b_scales);
    }

    RunTest<TypeParam, true>(a_host, b_host, a_scales, b_scales, c_host);

    // The reference takes B and its scales transposed.
    HostTensor<AccDataType> c_expected({kM, kN});
    c_expected.SetZero();
    reference_mx_gemm<ADataType, BDataType, e8m0_t, e8m0_t, AccDataType, AccDataType>(
        a_host, b_host.transpose(), c_expected, a_scales, b_scales.transpose());

    const float ref_max = *std::max_element(c_expected.mData.begin(), c_expected.mData.end());
    const auto rel_tol =
        ck_tile::get_relative_threshold<ADataType, AccDataType, AccDataType>(kK);
    const auto abs_tol =
        ck_tile::get_absolute_threshold<ADataType, AccDataType, AccDataType>(ref_max, kK);

    EXPECT_TRUE(check_err(c_host, c_expected, "Scale16 TransposeC error.", rel_tol, abs_tol));
}

View File

@@ -0,0 +1,362 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// Unit test for a 32x32x128 block-scaled GEMM built from the 16x16x128 scale16
// warp gemm (V_WMMA_SCALE16_F32_16X16X128_F8F6F4), modeled after the MX GEMM
// pipelines (BlockGemmARegBRegCRegV1 / mx_flatmm pipeline): all A/B sub-tiles for
// the 2x2 block-level loop are preloaded into registers in advance, then consumed
// by the warp-gemm loop. The K-scales are staged once across the 32 lanes of the
// wave (one row/column per lane, each int64 carrying that row's 8 e8m0 K-scales),
// and the same scale registers are passed to every warp-gemm call. OpSelA/OpSelB
// are threaded through the warp-gemm call as the pipeline does and select the
// per-M-block / per-N-block scales (lanes 0..15 for block 0, lanes 16..31 for
// block 1).
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
using namespace ck_tile;
// Distribution that stages one scale row per lane: NumRows rows mapped across the
// NumRows lanes of the wave (lane L holds row L's NumScales e8m0 K-scales = one int64).
// For a 32-row tile this yields lanes 0..15 -> rows 0..15, lanes 16..31 -> rows 16..31,
// which is exactly the layout the hardware selects between via SCL_OPSEL (OpSel).
// The encoding lists, in order: sequence<1>, the per-dimension lengths (NumRows and
// NumScales), the P mapping 1/0 and the Y mapping 2/0.
// NOTE(review): the "= one int64" remark holds only when NumScales is 8 -- confirm
// that callers never instantiate this with another value.
template <index_t NumRows, index_t NumScales>
CK_TILE_DEVICE static constexpr auto MakeScaleDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NumRows>, sequence<NumScales>>,
tuple<sequence<1>>,
tuple<sequence<0>>,
sequence<2>,
sequence<0>>{});
}
// Pipeline-style kernel: preload all scales for the 2x2 block in advance, then
// run a 2x2 block-level loop over the 16x16x128 scale16 warp gemm. Each (mIter,
// nIter) sub-tile consumes its own independent per-M/per-N-block K-scales.
//
// M and N are the full tile extents; they are split into MIter x NIter sub-tiles of
// MPerWarp x NPerWarp (16 x 16), and each sub-tile is one warp gemm call over all of K.
template <typename AType,
typename BType,
typename CType,
index_t M,
index_t N,
index_t K,
bool TransposeC>
struct WarpGemmScale16BlockLoopKernel
{
// The host launches this kernel with a block of 32 threads.
static constexpr int kBlockSize = 32;
// One e8m0 scale per 16 K-elements, so NumScales scales per row (8 for K = 128).
static constexpr index_t ScaleBlockK = 16;
static constexpr index_t NumScales = K / ScaleBlockK;
static constexpr index_t MPerWarp = 16;
static constexpr index_t NPerWarp = 16;
// Number of 16-row / 16-column sub-tiles along M and N (2 each for a 32x32 tile).
static constexpr index_t MIter = M / MPerWarp;
static constexpr index_t NIter = N / NPerWarp;
// A, B, C, ScaleA, ScaleB are device pointers, cast below to AType*, BType*, CType*
// and e8m0_t* respectively. All views are row-major (stride of the last dim is 1).
__device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const
{
const auto a_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<AType*>(A),
make_tuple(M, K),
make_tuple(K, number<1>{}),
number<K>{},
number<1>{});
// B is stored as N x K (one row per output column).
const auto b_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<BType*>(B),
make_tuple(N, K),
make_tuple(K, number<1>{}),
number<K>{},
number<1>{});
const auto c_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<CType*>(C),
make_tuple(M, N),
make_tuple(N, number<1>{}),
number<N>{},
number<1>{});
// Scale views: M x NumScales for A and N x NumScales for B.
const auto sa_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<e8m0_t*>(ScaleA),
make_tuple(M, NumScales),
make_tuple(NumScales, number<1>{}),
number<NumScales>{},
number<1>{});
const auto sb_view =
make_naive_tensor_view<address_space_enum::global>(static_cast<e8m0_t*>(ScaleB),
make_tuple(N, NumScales),
make_tuple(NumScales, number<1>{}),
number<NumScales>{},
number<1>{});
using WarpGemm = WarpGemmDispatcher<
AType, // ADataType: A element type (fp8_t / bf8_t)
BType, // BDataType: B element type (fp8_t / bf8_t)
float, // AccDataType: accumulator type (F32)
MPerWarp, // MPerWave: warp-tile M (16 - native op size)
NPerWarp, // NPerWave: warp-tile N (16 - native op size)
K, // KPerWave: warp-tile K (128)
TransposeC, // TransposeC: use transposed-C distribution
false, // SwizzleA: A LDS swizzle layout (off)
false, // UseStructuredSparsity: 2:4 sparsity (off)
WGAttrNumAccessEnum::Default, // AttrNumAccessA: A num-access attribute
WGAttrNumAccessEnum::Default, // AttrNumAccessB: B num-access attribute
true>; // IsScale16: select the scale16 WMMA variant
// Operand and result distributions come from the warp gemm; the scale distribution
// is the test-local one, sized for M rows.
// NOTE(review): the same scale_dstr (built with M rows) is reused for the B scales
// below, which is only consistent when N == M -- confirm.
constexpr auto a_dstr = typename WarpGemm::AWarpDstr{};
constexpr auto b_dstr = typename WarpGemm::BWarpDstr{};
constexpr auto c_dstr = typename WarpGemm::CWarpDstr{};
constexpr auto scale_dstr = MakeScaleDistribution<M, NumScales>();
// ---- Preload phase ----
// A/B sub-tiles for the 2x2 block loop. Scales are staged ONCE across all 32
// lanes (not per-block arrays): a single int64 A-scale and B-scale carry every
// row's/col's K-scales, and the per-block selection is done in hardware by
// OpSelA/OpSelB (SCL_OPSEL) -- lanes 0..15 for block 0, lanes 16..31 for block 1.
//
// The decltype expression only names the tile type returned by load_tile; the
// array elements are filled by the static_for loop that follows.
statically_indexed_array<decltype(load_tile(
make_tile_window(a_view,
make_tuple(number<MPerWarp>{}, number<K>{}),
make_multi_index(0, 0),
a_dstr))),
MIter>
a_tiles;
// Load the MPerWarp x K A sub-tile starting at row mIter * MPerWarp.
static_for<0, MIter, 1>{}([&](auto mIter) {
auto a_win = make_tile_window(a_view,
make_tuple(number<MPerWarp>{}, number<K>{}),
make_multi_index(mIter.value * MPerWarp, 0),
a_dstr);
a_tiles(mIter) = load_tile(a_win);
});
statically_indexed_array<decltype(load_tile(
make_tile_window(b_view,
make_tuple(number<NPerWarp>{}, number<K>{}),
make_multi_index(0, 0),
b_dstr))),
NIter>
b_tiles;
// Load the NPerWarp x K B sub-tile starting at row nIter * NPerWarp.
static_for<0, NIter, 1>{}([&](auto nIter) {
auto b_win = make_tile_window(b_view,
make_tuple(number<NPerWarp>{}, number<K>{}),
make_multi_index(nIter.value * NPerWarp, 0),
b_dstr);
b_tiles(nIter) = load_tile(b_win);
});
// Single 64-bit scale scalars: all M rows / N cols staged across the 32 lanes.
auto sa_win = make_tile_window(sa_view,
make_tuple(number<M>{}, number<NumScales>{}),
make_multi_index(0, 0),
scale_dstr);
// Reinterpret this thread's NumScales e8m0 values as a single int64_t scale operand.
const int64_t scale_a =
bit_cast<int64_t>(load_tile(sa_win)
.get_thread_buffer()
.template get_as<ext_vector_t<e8m0_t, NumScales>>()[number<0>{}]);
auto sb_win = make_tile_window(sb_view,
make_tuple(number<N>{}, number<NumScales>{}),
make_multi_index(0, 0),
scale_dstr);
const int64_t scale_b =
bit_cast<int64_t>(load_tile(sb_win)
.get_thread_buffer()
.template get_as<ext_vector_t<e8m0_t, NumScales>>()[number<0>{}]);
// ---- Compute phase ----
// 2x2 block-level loop. The same int64 scale scalars are passed to every call;
// the correct per-block K-scales are fetched in hardware via OpSelA/OpSelB, which
// select lanes 0..15 (block 0) vs lanes 16..31 (block 1) of the scale register.
static_for<0, MIter, 1>{}([&](auto mIter) {
static_for<0, NIter, 1>{}([&](auto nIter) {
// Each sub-tile result is stored at (mIter * MPerWarp, nIter * NPerWarp) in C.
auto c_win = make_tile_window(
c_view,
make_tuple(number<MPerWarp>{}, number<NPerWarp>{}),
make_multi_index(mIter.value * MPerWarp, nIter.value * NPerWarp),
c_dstr);
auto c_tile =
WarpGemm{}.template operator()<OpSelA<mIter.value>, OpSelB<nIter.value>>(
a_tiles(mIter), b_tiles(nIter), scale_a, scale_b);
store_tile(c_win, c_tile);
});
});
}
};
// Compile-time description of one typed-test case: the A/B/accumulator element
// types and the M/N/K extents handled by the wave.
template <typename A, typename B, typename Acc, index_t M, index_t N, index_t K>
struct WGDispCase
{
    // Tile extents.
    static constexpr index_t MPerWave = M;
    static constexpr index_t NPerWave = N;
    static constexpr index_t KPerWave = K;

    // Element types.
    using AType   = A;
    using BType   = B;
    using AccType = Acc;
};
template <typename Case, bool TransposeC>
static void RunTest(const HostTensor<typename Case::AType>& A,
const HostTensor<typename Case::BType>& B,
const HostTensor<e8m0_t>& ScaleA,
const HostTensor<e8m0_t>& ScaleB,
HostTensor<typename Case::AccType>& C)
{
DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB);
dim3 grid(1), block{32};
using K = WarpGemmScale16BlockLoopKernel<typename Case::AType,
typename Case::BType,
typename Case::AccType,
Case::MPerWave,
Case::NPerWave,
Case::KPerWave,
TransposeC>;
(void)launch_kernel(stream_config{nullptr, true, 0, 0, 1},
make_kernel(K{},
grid,
block,
0,
Ad.GetDeviceBuffer(),
Bd.GetDeviceBuffer(),
Cd.GetDeviceBuffer(),
SAd.GetDeviceBuffer(),
SBd.GetDeviceBuffer()));
Cd.FromDevice(C.mData.data());
}
using WGDispatcherTypesList = ::testing::Types<WGDispCase<fp8_t, fp8_t, float, 32, 32, 128>,
WGDispCase<bf8_t, bf8_t, float, 32, 32, 128>,
WGDispCase<fp8_t, bf8_t, float, 32, 32, 128>,
WGDispCase<bf8_t, fp8_t, float, 32, 32, 128>>;
// Typed-test fixture for the scale16 warp-gemm tests. It carries no state; the `Case`
// parameter only selects which test-case descriptor a TYPED_TEST body sees as TypeParam.
template <typename Case>
class WGScale16Test : public ::testing::Test
{
};
// Instantiate every TYPED_TEST of this fixture once per entry of WGDispatcherTypesList.
TYPED_TEST_SUITE(WGScale16Test, WGDispatcherTypesList);
// Uniform-scale check: A and B hold uniformly distributed values in [-5, 5], every A scale
// is built from 2.f and every B scale from 4.f. The device result (TransposeC = false) is
// compared with reference_mx_gemm using check_err's default tolerances.
TYPED_TEST(WGScale16Test, Scale16_32x32x128_UniformScale)
{
    using AType = typename TypeParam::AType;
    using BType = typename TypeParam::BType;
    using CType = typename TypeParam::AccType;

    constexpr index_t kM = TypeParam::MPerWave;
    constexpr index_t kN = TypeParam::NPerWave;
    constexpr index_t kK = TypeParam::KPerWave;
    // One e8m0 scale per group of 16 elements along K.
    constexpr index_t kNumScales = kK / 16;

    // Host tensors: A is M x K, B is N x K, scales are (rows x kNumScales).
    HostTensor<AType> a_host({kM, kK});
    HostTensor<BType> b_host({kN, kK});
    HostTensor<CType> c_host({kM, kN});
    HostTensor<e8m0_t> scale_a_host({kM, kNumScales});
    HostTensor<e8m0_t> scale_b_host({kN, kNumScales});

    FillUniformDistribution<AType>{-5.f, 5.f}(a_host);
    FillUniformDistribution<BType>{-5.f, 5.f}(b_host);
    c_host.SetZero();
    FillConstant<e8m0_t>{e8m0_t{2.f}}(scale_a_host);
    FillConstant<e8m0_t>{e8m0_t{4.f}}(scale_b_host);

    // Device run.
    RunTest<TypeParam, false>(a_host, b_host, scale_a_host, scale_b_host, c_host);

    // Host reference; B and its scales are handed over transposed.
    HostTensor<CType> c_expected({kM, kN});
    c_expected.SetZero();
    reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
        a_host, b_host.transpose(), c_expected, scale_a_host, scale_b_host.transpose());

    EXPECT_TRUE(check_err(c_host, c_expected, "Scale16 32x32x128 uniform scale error."));
}
// Random-data / random-scale check (TransposeC = false).
//
// A and B are filled with seeded uniform values in [-5, 5]; each e8m0 scale is drawn
// independently as a raw biased exponent code in [bias - 4, bias + 2]. The device result is
// compared with reference_mx_gemm using tolerances derived from K and from the largest
// accumulator magnitude in the reference.
TYPED_TEST(WGScale16Test, Scale16_32x32x128_RandomDataRandomScales)
{
    using Case                  = TypeParam;
    using AType                 = typename Case::AType;
    using BType                 = typename Case::BType;
    using CType                 = typename Case::AccType;
    constexpr index_t M         = Case::MPerWave;
    constexpr index_t N         = Case::NPerWave;
    constexpr index_t K         = Case::KPerWave;
    constexpr index_t NumScales = K / 16; // one scale per 16 elements along K

    HostTensor<AType> A({M, K});
    HostTensor<BType> B({N, K});
    HostTensor<CType> C({M, N});
    HostTensor<e8m0_t> sA({M, NumScales});
    HostTensor<e8m0_t> sB({N, NumScales});

    FillUniformDistribution<AType>{-5.f, 5.f, 42}(A);
    FillUniformDistribution<BType>{-5.f, 5.f, 137}(B);
    C.SetZero();
    {
        constexpr int bias = ck_tile::numeric_traits<e8m0_t>::bias;
        std::mt19937 gen(9999);
        std::uniform_int_distribution<int> dist(bias - 4, bias + 2);
        // Both tensors draw from the same generator, sA first, so the sequence of scale
        // values is fully determined by the seed.
        const auto fill_scales = [&gen, &dist](HostTensor<e8m0_t>& scales) {
            for(auto& s : scales.mData)
                s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
        };
        fill_scales(sA);
        fill_scales(sB);
    }

    RunTest<Case, false>(A, B, sA, sB, C);

    HostTensor<CType> C_ref({M, N});
    C_ref.SetZero();
    reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
        A, B.transpose(), C_ref, sA, sB.transpose());

    // The inputs are symmetric around zero, so the largest-magnitude accumulator may be
    // negative. Size the absolute tolerance from max(|min|, |max|) rather than from the
    // signed maximum, which would under-estimate it in that case.
    const auto extremes = std::minmax_element(C_ref.mData.begin(), C_ref.mData.end());
    const float max_acc =
        std::max(static_cast<float>(*extremes.second), -static_cast<float>(*extremes.first));
    const auto rtol = ck_tile::get_relative_threshold<AType, CType, CType>(K);
    const auto atol = ck_tile::get_absolute_threshold<AType, CType, CType>(max_acc, K);
    EXPECT_TRUE(
        check_err(C, C_ref, "Scale16 32x32x128 random data + random scales error.", rtol, atol));
}
// TransposeC check: same setup as the random-data test (seeded uniform A/B in [-5, 5],
// e8m0 scales drawn as raw biased exponent codes in [bias - 4, bias + 2]) but the kernel is
// instantiated with TransposeC = true. The result is still compared element-wise against
// the M x N output of reference_mx_gemm.
TYPED_TEST(WGScale16Test, Scale16_32x32x128_TransposeC)
{
    using Case                  = TypeParam;
    using AType                 = typename Case::AType;
    using BType                 = typename Case::BType;
    using CType                 = typename Case::AccType;
    constexpr index_t M         = Case::MPerWave;
    constexpr index_t N         = Case::NPerWave;
    constexpr index_t K         = Case::KPerWave;
    constexpr index_t NumScales = K / 16; // one scale per 16 elements along K

    HostTensor<AType> A({M, K});
    HostTensor<BType> B({N, K});
    HostTensor<CType> C({M, N});
    HostTensor<e8m0_t> sA({M, NumScales});
    HostTensor<e8m0_t> sB({N, NumScales});

    FillUniformDistribution<AType>{-5.f, 5.f, 77}(A);
    FillUniformDistribution<BType>{-5.f, 5.f, 88}(B);
    C.SetZero();
    {
        constexpr int bias = ck_tile::numeric_traits<e8m0_t>::bias;
        std::mt19937 gen(5555);
        std::uniform_int_distribution<int> dist(bias - 4, bias + 2);
        // Both tensors draw from the same generator, sA first, so the sequence of scale
        // values is fully determined by the seed.
        const auto fill_scales = [&gen, &dist](HostTensor<e8m0_t>& scales) {
            for(auto& s : scales.mData)
                s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
        };
        fill_scales(sA);
        fill_scales(sB);
    }

    RunTest<Case, true>(A, B, sA, sB, C);

    HostTensor<CType> C_ref({M, N});
    C_ref.SetZero();
    reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
        A, B.transpose(), C_ref, sA, sB.transpose());

    // The inputs are symmetric around zero, so the largest-magnitude accumulator may be
    // negative. Size the absolute tolerance from max(|min|, |max|) rather than from the
    // signed maximum, which would under-estimate it in that case.
    const auto extremes = std::minmax_element(C_ref.mData.begin(), C_ref.mData.end());
    const float max_acc =
        std::max(static_cast<float>(*extremes.second), -static_cast<float>(*extremes.first));
    const auto rtol = ck_tile::get_relative_threshold<AType, CType, CType>(K);
    const auto atol = ck_tile::get_absolute_threshold<AType, CType, CType>(max_acc, K);
    EXPECT_TRUE(check_err(C, C_ref, "Scale16 32x32x128 TransposeC error.", rtol, atol));
}