mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 08:48:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#7725 (commit eef7e12)
[GFX1250][CK_TILE] Add scale16 warp gemm unit tests ## Summary - Add scale16 WMMA intrinsic overloads and int64_t forwarding to warp gemm layers for gfx1250 - Add comprehensive wave-level unit tests for scale16 warp gemm (16x16x128 and 32x32x128 tile sizes) - Test all fp8/bf8 type combinations and TransposeC variants - Fix WarpGemm wrapper for non-uniform scale16 configurations Stacked on #7724 (FillUniformScaleDistribution / MX GEMM scale init). Pipeline enablement follows in the next PR.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
45a8f96c66
commit
e01603bc31
@@ -154,6 +154,19 @@ struct CTransposedWarpDstrEncodingTrait
|
||||
typename Impl::kCTYs2RHsMinor>;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
template <typename T, typename = void>
|
||||
struct mx_type_enable_or_void
|
||||
{
|
||||
using type = void;
|
||||
};
|
||||
template <typename T>
|
||||
struct mx_type_enable_or_void<T, std::void_t<typename T::MXTypeEnableType>>
|
||||
{
|
||||
using type = typename T::MXTypeEnableType;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename WarpGemmAttributeWmmaImpl_,
|
||||
bool kTransC = false,
|
||||
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Single,
|
||||
@@ -162,18 +175,21 @@ struct WarpGemmAttributeWmma
|
||||
{
|
||||
using Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>;
|
||||
|
||||
// When kTransC is true and A/B types differ, we need an impl with swapped types
|
||||
using TransposedImpl =
|
||||
std::conditional_t<kTransC &&
|
||||
!std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<typename Impl::TraitsType::ArchType,
|
||||
typename Impl::BDataType,
|
||||
typename Impl::ADataType,
|
||||
typename Impl::CDataType,
|
||||
Impl::kM,
|
||||
Impl::kN,
|
||||
Impl::kK>>,
|
||||
Impl>;
|
||||
// When kTransC is true and A/B types differ, we need an impl with swapped types.
|
||||
// Propagate MXTypeEnable (e.g., WmmaScale16Tag) so the transposed impl uses the
|
||||
// same WmmaTraits specialization family.
|
||||
using TransposedImpl = std::conditional_t<
|
||||
kTransC && !std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
|
||||
WarpGemmAttributeWmmaImpl<
|
||||
WmmaTraits<typename Impl::TraitsType::ArchType,
|
||||
typename Impl::BDataType,
|
||||
typename Impl::ADataType,
|
||||
typename Impl::CDataType,
|
||||
Impl::kM,
|
||||
Impl::kN,
|
||||
Impl::kK,
|
||||
typename detail::mx_type_enable_or_void<typename Impl::TraitsType>::type>>,
|
||||
Impl>;
|
||||
|
||||
using ADataType = typename Impl::ADataType;
|
||||
using BDataType = typename Impl::BDataType;
|
||||
|
||||
@@ -204,6 +204,13 @@ template <typename AType, typename BType>
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x128_f8f6f4 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, AType, BType, float, 16, 16, 128>>;
|
||||
|
||||
// WmmaScale16Tag (declared above) is passed as MXTypeEnable to WmmaTraits to select scale16
|
||||
// specializations. These override kAK1PerLane=16 (-> sequence<4,2,16>) and use int64_t scales
|
||||
// for V_WMMA_SCALE16_F32_16X16X128_F8F6F4, vs the default layout / int32_t.
|
||||
template <typename AType, typename BType>
|
||||
using WarpGemmAttributeWmmaImpl_f32_16x16x128_f8f6f4_scale16 = WarpGemmAttributeWmmaImpl<
|
||||
WmmaTraits<gfx125_t, AType, BType, float, 16, 16, 128, WmmaScale16Tag>>;
|
||||
|
||||
template <typename AType, typename BType>
|
||||
using WarpGemmAttributeWmmaImpl_f32_32x32x128_f8f6f4 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, AType, BType, float, 32, 32, 128>>;
|
||||
|
||||
@@ -427,6 +427,82 @@ struct WmmaTraits<gfx125_t, bf8_t, bf8_t, fp16_t, 16, 16, 64>
|
||||
}
|
||||
};
|
||||
|
||||
// f8f6f4 specialization - GFX125
|
||||
enum F8F6F4OpDataTypeEnum
|
||||
{
|
||||
E4M3, // 0x0
|
||||
E5M2, // 0x1
|
||||
E2M3, // 0x2
|
||||
E3M2, // 0x3
|
||||
E2M1, // 0x4
|
||||
};
|
||||
|
||||
// Traits for MX data types used in f8f6f4 intrinsics
|
||||
template <typename T>
|
||||
struct MXDataTypeTrait;
|
||||
|
||||
template <>
|
||||
struct MXDataTypeTrait<fp8_t>
|
||||
{
|
||||
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E4M3;
|
||||
using VecType = int32x16_t;
|
||||
|
||||
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MXDataTypeTrait<bf8_t>
|
||||
{
|
||||
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E5M2;
|
||||
using VecType = int32x16_t;
|
||||
|
||||
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MXDataTypeTrait<pk_fp4_t>
|
||||
{
|
||||
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M1;
|
||||
using VecType = int32x8_t;
|
||||
|
||||
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x8_t& vec)
|
||||
{
|
||||
return int32x16_t{
|
||||
vec[0], vec[1], vec[2], vec[3], vec[4], vec[5], vec[6], vec[7], 0, 0, 0, 0, 0, 0, 0, 0};
|
||||
}
|
||||
};
|
||||
|
||||
// pk_fp6x16_t (legacy): 16 fp6 e2m3 values packed into 3 int32 (96 bits).
|
||||
// At 16x16x128 each lane holds 64 fp6 elements = 4 packs = 12 int32
|
||||
// (f6x16xN_tt<4, f6_kind::fp6>, whose storage is int32_t data[12]),
|
||||
// padded with 4 zero lanes to fit the 16-wide f8f6f4 wmma input.
|
||||
template <>
|
||||
struct MXDataTypeTrait<pk_fp6x16_t>
|
||||
{
|
||||
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M3;
|
||||
using VecType = f6x16xN_tt<4, f6_kind::fp6>;
|
||||
|
||||
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const f6x16xN_tt<4, f6_kind::fp6>& vec)
|
||||
{
|
||||
return int32x16_t{vec.data[0],
|
||||
vec.data[1],
|
||||
vec.data[2],
|
||||
vec.data[3],
|
||||
vec.data[4],
|
||||
vec.data[5],
|
||||
vec.data[6],
|
||||
vec.data[7],
|
||||
vec.data[8],
|
||||
vec.data[9],
|
||||
vec.data[10],
|
||||
vec.data[11],
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0};
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, bf8_t, bf8_t, float, 16, 16, 128>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float, 128>
|
||||
@@ -508,6 +584,232 @@ struct WmmaTraits<gfx125_t, bf8_t, fp8_t, float, 16, 16, 128>
|
||||
}
|
||||
};
|
||||
|
||||
// scale16 specializations: fp8xfp8, bf8xbf8, fp8xbf8, bf8xfp8
|
||||
// Override kAK1PerLane/kBK1PerLane to 16 for scale16 register layout -> sequence<4,2,16>
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, fp8_t, fp8_t, float, 16, 16, 128, WmmaScale16Tag>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, fp8_t, float, 128>
|
||||
{
|
||||
using ArchType = gfx125_t;
|
||||
using MXTypeEnableType = WmmaScale16Tag;
|
||||
|
||||
static constexpr index_t kAK1PerLane = 16;
|
||||
static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane);
|
||||
static constexpr index_t kBK1PerLane = 16;
|
||||
static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane);
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&)
|
||||
{
|
||||
static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments");
|
||||
return CVecType{0};
|
||||
}
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
|
||||
const int64_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int64_t& b_scale,
|
||||
const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx125__
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
using ATraits = MXDataTypeTrait<fp8_t>;
|
||||
using BTraits = MXDataTypeTrait<fp8_t>;
|
||||
return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4(
|
||||
ATraits::OpDataType,
|
||||
ATraits::to_wmma_vec(bit_cast<typename ATraits::VecType>(a_vec)),
|
||||
BTraits::OpDataType,
|
||||
BTraits::to_wmma_vec(bit_cast<typename BTraits::VecType>(b_vec)),
|
||||
0,
|
||||
bit_cast<fp32x8_t>(c_vec),
|
||||
P::op_sel_a,
|
||||
P::scale_a,
|
||||
a_scale,
|
||||
P::op_sel_b,
|
||||
P::scale_b,
|
||||
b_scale,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = a_scale;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = b_scale;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, bf8_t, bf8_t, float, 16, 16, 128, WmmaScale16Tag>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, bf8_t, float, 128>
|
||||
{
|
||||
using ArchType = gfx125_t;
|
||||
using MXTypeEnableType = WmmaScale16Tag;
|
||||
|
||||
static constexpr index_t kAK1PerLane = 16;
|
||||
static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane);
|
||||
static constexpr index_t kBK1PerLane = 16;
|
||||
static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane);
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&)
|
||||
{
|
||||
static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments");
|
||||
return CVecType{0};
|
||||
}
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
|
||||
const int64_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int64_t& b_scale,
|
||||
const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx125__
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
using ATraits = MXDataTypeTrait<bf8_t>;
|
||||
using BTraits = MXDataTypeTrait<bf8_t>;
|
||||
return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4(
|
||||
ATraits::OpDataType,
|
||||
ATraits::to_wmma_vec(bit_cast<typename ATraits::VecType>(a_vec)),
|
||||
BTraits::OpDataType,
|
||||
BTraits::to_wmma_vec(bit_cast<typename BTraits::VecType>(b_vec)),
|
||||
0,
|
||||
bit_cast<fp32x8_t>(c_vec),
|
||||
P::op_sel_a,
|
||||
P::scale_a,
|
||||
a_scale,
|
||||
P::op_sel_b,
|
||||
P::scale_b,
|
||||
b_scale,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = a_scale;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = b_scale;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, fp8_t, bf8_t, float, 16, 16, 128, WmmaScale16Tag>
|
||||
: WmmaTraitsBase<gfx12_t, fp8_t, bf8_t, float, 128>
|
||||
{
|
||||
using ArchType = gfx125_t;
|
||||
using MXTypeEnableType = WmmaScale16Tag;
|
||||
|
||||
static constexpr index_t kAK1PerLane = 16;
|
||||
static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane);
|
||||
static constexpr index_t kBK1PerLane = 16;
|
||||
static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane);
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&)
|
||||
{
|
||||
static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments");
|
||||
return CVecType{0};
|
||||
}
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
|
||||
const int64_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int64_t& b_scale,
|
||||
const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx125__
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
using ATraits = MXDataTypeTrait<fp8_t>;
|
||||
using BTraits = MXDataTypeTrait<bf8_t>;
|
||||
return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4(
|
||||
ATraits::OpDataType,
|
||||
ATraits::to_wmma_vec(bit_cast<typename ATraits::VecType>(a_vec)),
|
||||
BTraits::OpDataType,
|
||||
BTraits::to_wmma_vec(bit_cast<typename BTraits::VecType>(b_vec)),
|
||||
0,
|
||||
bit_cast<fp32x8_t>(c_vec),
|
||||
P::op_sel_a,
|
||||
P::scale_a,
|
||||
a_scale,
|
||||
P::op_sel_b,
|
||||
P::scale_b,
|
||||
b_scale,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = a_scale;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = b_scale;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, bf8_t, fp8_t, float, 16, 16, 128, WmmaScale16Tag>
|
||||
: WmmaTraitsBase<gfx12_t, bf8_t, fp8_t, float, 128>
|
||||
{
|
||||
using ArchType = gfx125_t;
|
||||
using MXTypeEnableType = WmmaScale16Tag;
|
||||
|
||||
static constexpr index_t kAK1PerLane = 16;
|
||||
static constexpr index_t kAK0PerLane = kK / (kAK1PerLane * kABKLane);
|
||||
static constexpr index_t kBK1PerLane = 16;
|
||||
static constexpr index_t kBK0PerLane = kK / (kBK1PerLane * kABKLane);
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType&, const BVecType&, const CVecType&)
|
||||
{
|
||||
static_assert(sizeof...(Params) < 0, "scale16 WmmaTraits requires int64_t scale arguments");
|
||||
return CVecType{0};
|
||||
}
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
|
||||
const int64_t& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int64_t& b_scale,
|
||||
const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx125__
|
||||
using P = WarpGemmParamsParser<Params...>;
|
||||
using ATraits = MXDataTypeTrait<bf8_t>;
|
||||
using BTraits = MXDataTypeTrait<fp8_t>;
|
||||
return __builtin_amdgcn_wmma_scale16_f32_16x16x128_f8f6f4(
|
||||
ATraits::OpDataType,
|
||||
ATraits::to_wmma_vec(bit_cast<typename ATraits::VecType>(a_vec)),
|
||||
BTraits::OpDataType,
|
||||
BTraits::to_wmma_vec(bit_cast<typename BTraits::VecType>(b_vec)),
|
||||
0,
|
||||
bit_cast<fp32x8_t>(c_vec),
|
||||
P::op_sel_a,
|
||||
P::scale_a,
|
||||
a_scale,
|
||||
P::op_sel_b,
|
||||
P::scale_b,
|
||||
b_scale,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = a_scale;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = b_scale;
|
||||
ck_tile::ignore = c_vec;
|
||||
return CVecType{0};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// 32x16x128 f4 specialization - GFX125
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 16, 128>
|
||||
@@ -648,82 +950,6 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128, WmmaScale16T
|
||||
{
|
||||
};
|
||||
|
||||
// f8f6f4 specialization - GFX125
|
||||
enum F8F6F4OpDataTypeEnum
|
||||
{
|
||||
E4M3, // 0x0
|
||||
E5M2, // 0x1
|
||||
E2M3, // 0x2
|
||||
E3M2, // 0x3
|
||||
E2M1, // 0x4
|
||||
};
|
||||
|
||||
// Traits for MX data types used in f8f6f4 intrinsics
|
||||
template <typename T>
|
||||
struct MXDataTypeTrait;
|
||||
|
||||
template <>
|
||||
struct MXDataTypeTrait<fp8_t>
|
||||
{
|
||||
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E4M3;
|
||||
using VecType = int32x16_t;
|
||||
|
||||
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MXDataTypeTrait<bf8_t>
|
||||
{
|
||||
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E5M2;
|
||||
using VecType = int32x16_t;
|
||||
|
||||
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x16_t& vec) { return vec; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MXDataTypeTrait<pk_fp4_t>
|
||||
{
|
||||
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M1;
|
||||
using VecType = int32x8_t;
|
||||
|
||||
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const int32x8_t& vec)
|
||||
{
|
||||
return int32x16_t{
|
||||
vec[0], vec[1], vec[2], vec[3], vec[4], vec[5], vec[6], vec[7], 0, 0, 0, 0, 0, 0, 0, 0};
|
||||
}
|
||||
};
|
||||
|
||||
// pk_fp6x16_t (legacy): 16 fp6 e2m3 values packed into 3 int32 (96 bits).
|
||||
// At 16x16x128 each lane holds 64 fp6 elements = 4 packs = 12 int32
|
||||
// (f6x16xN_tt<4, f6_kind::fp6>, whose storage is int32_t data[12]),
|
||||
// padded with 4 zero lanes to fit the 16-wide f8f6f4 wmma input.
|
||||
template <>
|
||||
struct MXDataTypeTrait<pk_fp6x16_t>
|
||||
{
|
||||
static constexpr F8F6F4OpDataTypeEnum OpDataType = F8F6F4OpDataTypeEnum::E2M3;
|
||||
using VecType = f6x16xN_tt<4, f6_kind::fp6>;
|
||||
|
||||
CK_TILE_DEVICE static int32x16_t to_wmma_vec(const f6x16xN_tt<4, f6_kind::fp6>& vec)
|
||||
{
|
||||
return int32x16_t{vec.data[0],
|
||||
vec.data[1],
|
||||
vec.data[2],
|
||||
vec.data[3],
|
||||
vec.data[4],
|
||||
vec.data[5],
|
||||
vec.data[6],
|
||||
vec.data[7],
|
||||
vec.data[8],
|
||||
vec.data[9],
|
||||
vec.data[10],
|
||||
vec.data[11],
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0};
|
||||
}
|
||||
};
|
||||
|
||||
// Unified WmmaTraits for f8f6f4 combinations
|
||||
template <typename AType, typename BType>
|
||||
struct WmmaTraits<
|
||||
|
||||
@@ -240,6 +240,9 @@ template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<b
|
||||
|
||||
// F8F6F4 Mixed precision cases
|
||||
template<typename A, typename B, bool TransposeC, WGAttrNumAccessEnum AttrNumAccessA, WGAttrNumAccessEnum AttrNumAccessB> struct Dispatcher<A, B, float, 16, 16, 128, TransposeC, false, false, AttrNumAccessA, AttrNumAccessB> : WmmaTag { using Type = WarpGemmWmma_f32_16x16x128_f8f6f4<A, B, TransposeC, AttrNumAccessA, AttrNumAccessB>; };
|
||||
|
||||
// F8F6F4 Scale16 (IsScale16=true)
|
||||
template<typename A, typename B, bool TransposeC, WGAttrNumAccessEnum AttrNumAccessA, WGAttrNumAccessEnum AttrNumAccessB> struct Dispatcher<A, B, float, 16, 16, 128, TransposeC, false, false, AttrNumAccessA, AttrNumAccessB, true> : WmmaTag { using Type = WarpGemmWmma_f32_16x16x128_f8f6f4_scale16<A, B, TransposeC, AttrNumAccessA, AttrNumAccessB>; };
|
||||
#else
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
|
||||
@@ -216,6 +216,17 @@ using WarpGemmWmma_f32_16x16x128_f8f6f4 =
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>>;
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
bool kTransC,
|
||||
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Default,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = WGAttrNumAccessEnum::Default>
|
||||
using WarpGemmWmma_f32_16x16x128_f8f6f4_scale16 = WarpGemmImpl<
|
||||
WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_16x16x128_f8f6f4_scale16<AType, BType>,
|
||||
kTransC,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>>;
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
bool kTransC,
|
||||
|
||||
@@ -4,8 +4,14 @@
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp)
|
||||
endif()
|
||||
|
||||
|
||||
# Scale16 warp gemm tests for V_WMMA_SCALE16_F32_16X16X128_F8F6F4.
|
||||
# Each test covers 4 type combos (fp8xfp8, bf8xbf8, fp8xbf8, bf8xfp8) x 3 cases
|
||||
# (uniform scale, random scale, TransposeC) = 12 tests.
|
||||
# 16x16: single warp gemm call. 32x32: pipeline-style 2x2 block loop with per-block scales.
|
||||
if(GPU_TARGETS MATCHES "gfx125")
|
||||
add_gtest_executable(test_ck_tile_wg_16x16x128_fp8_scale16 test_f32_16x16x128_fp8_scale16.cpp)
|
||||
add_gtest_executable(test_ck_tile_wg_32x32x128_fp8_scale16 test_f32_32x32x128_fp8_scale16.cpp)
|
||||
add_gtest_executable(test_ck_tile_wmma_bf16_16x16x32_gfx1250 test_wmma_bf16_16x16x32_gfx1250.cpp)
|
||||
add_gtest_executable(test_ck_tile_wg_32x16x128_fp4 test_f32_32x16x128_fp4.cpp)
|
||||
endif()
|
||||
|
||||
299
test/ck_tile/warp_gemm/test_f32_16x16x128_fp8_scale16.cpp
Normal file
299
test/ck_tile/warp_gemm/test_f32_16x16x128_fp8_scale16.cpp
Normal file
@@ -0,0 +1,299 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template <index_t NumScales>
|
||||
CK_TILE_DEVICE static constexpr auto MakeScaleDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<2>,
|
||||
tuple<sequence<16>, sequence<NumScales>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>{});
|
||||
}
|
||||
|
||||
// Scale16 kernel using the WarpGemm wrapper (Layer 4).
|
||||
// WarpGemmWmma_f32_16x16x128_f8f6f4_scale16 produces AWarpDstr with sequence<4,2,16>,
|
||||
// matching the hardware's 1-scale-byte-per-16-K-elements register layout.
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
bool TransposeC>
|
||||
struct WarpGemmScale16Kernel
|
||||
{
|
||||
static constexpr int kBlockSize = 32;
|
||||
static constexpr index_t ScaleBlockK = 16;
|
||||
static constexpr index_t NumScales = K / ScaleBlockK;
|
||||
|
||||
__device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const
|
||||
{
|
||||
const auto a_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<AType*>(A),
|
||||
make_tuple(M, K),
|
||||
make_tuple(K, number<1>{}),
|
||||
number<K>{},
|
||||
number<1>{});
|
||||
const auto b_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<BType*>(B),
|
||||
make_tuple(N, K),
|
||||
make_tuple(K, number<1>{}),
|
||||
number<K>{},
|
||||
number<1>{});
|
||||
const auto c_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<CType*>(C),
|
||||
make_tuple(M, N),
|
||||
make_tuple(N, number<1>{}),
|
||||
number<N>{},
|
||||
number<1>{});
|
||||
const auto sa_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<e8m0_t*>(ScaleA),
|
||||
make_tuple(M, NumScales),
|
||||
make_tuple(NumScales, number<1>{}),
|
||||
number<NumScales>{},
|
||||
number<1>{});
|
||||
const auto sb_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<e8m0_t*>(ScaleB),
|
||||
make_tuple(N, NumScales),
|
||||
make_tuple(NumScales, number<1>{}),
|
||||
number<NumScales>{},
|
||||
number<1>{});
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<AType,
|
||||
BType,
|
||||
float,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
TransposeC,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Default,
|
||||
WGAttrNumAccessEnum::Default,
|
||||
true>;
|
||||
constexpr auto a_dstr = typename WarpGemm::AWarpDstr{};
|
||||
constexpr auto b_dstr = typename WarpGemm::BWarpDstr{};
|
||||
constexpr auto c_dstr = typename WarpGemm::CWarpDstr{};
|
||||
constexpr auto scale_dstr = MakeScaleDistribution<NumScales>();
|
||||
|
||||
auto a_win = make_tile_window(
|
||||
a_view, make_tuple(number<M>{}, number<K>{}), make_multi_index(0, 0), a_dstr);
|
||||
auto b_win = make_tile_window(
|
||||
b_view, make_tuple(number<N>{}, number<K>{}), make_multi_index(0, 0), b_dstr);
|
||||
auto c_win = make_tile_window(
|
||||
c_view, make_tuple(number<M>{}, number<N>{}), make_multi_index(0, 0), c_dstr);
|
||||
auto sa_win = make_tile_window(sa_view,
|
||||
make_tuple(number<M>{}, number<NumScales>{}),
|
||||
make_multi_index(0, 0),
|
||||
scale_dstr);
|
||||
auto sb_win = make_tile_window(sb_view,
|
||||
make_tuple(number<N>{}, number<NumScales>{}),
|
||||
make_multi_index(0, 0),
|
||||
scale_dstr);
|
||||
|
||||
auto a_tile = load_tile(a_win);
|
||||
auto b_tile = load_tile(b_win);
|
||||
auto sa_tile = load_tile(sa_win);
|
||||
auto sb_tile = load_tile(sb_win);
|
||||
|
||||
int64_t scale_a =
|
||||
bit_cast<int64_t>(sa_tile.get_thread_buffer()
|
||||
.template get_as<ext_vector_t<e8m0_t, NumScales>>()[number<0>{}]);
|
||||
int64_t scale_b =
|
||||
bit_cast<int64_t>(sb_tile.get_thread_buffer()
|
||||
.template get_as<ext_vector_t<e8m0_t, NumScales>>()[number<0>{}]);
|
||||
|
||||
auto c_tile = WarpGemm{}(a_tile, b_tile, scale_a, scale_b);
|
||||
store_tile(c_win, c_tile);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename A, typename B, typename Acc, index_t M, index_t N, index_t K>
|
||||
struct WGDispCase
|
||||
{
|
||||
using AType = A;
|
||||
using BType = B;
|
||||
using AccType = Acc;
|
||||
static constexpr index_t MPerWave = M;
|
||||
static constexpr index_t NPerWave = N;
|
||||
static constexpr index_t KPerWave = K;
|
||||
};
|
||||
|
||||
template <typename Case, bool TransposeC>
|
||||
static void RunTest(const HostTensor<typename Case::AType>& A,
|
||||
const HostTensor<typename Case::BType>& B,
|
||||
const HostTensor<e8m0_t>& ScaleA,
|
||||
const HostTensor<e8m0_t>& ScaleB,
|
||||
HostTensor<typename Case::AccType>& C)
|
||||
{
|
||||
DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB);
|
||||
dim3 grid(1), block{32};
|
||||
|
||||
using K = WarpGemmScale16Kernel<typename Case::AType,
|
||||
typename Case::BType,
|
||||
typename Case::AccType,
|
||||
Case::MPerWave,
|
||||
Case::NPerWave,
|
||||
Case::KPerWave,
|
||||
TransposeC>;
|
||||
|
||||
(void)launch_kernel(stream_config{nullptr, true, 0, 0, 1},
|
||||
make_kernel(K{},
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
Ad.GetDeviceBuffer(),
|
||||
Bd.GetDeviceBuffer(),
|
||||
Cd.GetDeviceBuffer(),
|
||||
SAd.GetDeviceBuffer(),
|
||||
SBd.GetDeviceBuffer()));
|
||||
|
||||
Cd.FromDevice(C.mData.data());
|
||||
}
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<WGDispCase<fp8_t, fp8_t, float, 16, 16, 128>,
|
||||
WGDispCase<bf8_t, bf8_t, float, 16, 16, 128>,
|
||||
WGDispCase<fp8_t, bf8_t, float, 16, 16, 128>,
|
||||
WGDispCase<bf8_t, fp8_t, float, 16, 16, 128>>;
|
||||
|
||||
template <typename Case>
|
||||
class WGScale16Test : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(WGScale16Test, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGScale16Test, Scale16_16x16x128_UniformScale)
|
||||
{
|
||||
using Case = TypeParam;
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr index_t M = Case::MPerWave;
|
||||
constexpr index_t N = Case::NPerWave;
|
||||
constexpr index_t K = Case::KPerWave;
|
||||
constexpr index_t NumScales = K / 16;
|
||||
|
||||
HostTensor<AType> A({M, K});
|
||||
HostTensor<BType> B({N, K});
|
||||
HostTensor<CType> C({M, N});
|
||||
HostTensor<e8m0_t> sA({M, NumScales});
|
||||
HostTensor<e8m0_t> sB({N, NumScales});
|
||||
|
||||
FillUniformDistribution<AType>{-5.f, 5.f}(A);
|
||||
FillUniformDistribution<BType>{-5.f, 5.f}(B);
|
||||
C.SetZero();
|
||||
FillConstant<e8m0_t>{e8m0_t{2.f}}(sA);
|
||||
FillConstant<e8m0_t>{e8m0_t{4.f}}(sB);
|
||||
|
||||
RunTest<Case, false>(A, B, sA, sB, C);
|
||||
|
||||
HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
|
||||
A, B.transpose(), C_ref, sA, sB.transpose());
|
||||
|
||||
EXPECT_TRUE(check_err(C, C_ref, "Scale16 uniform scale error."));
|
||||
}
|
||||
|
||||
TYPED_TEST(WGScale16Test, Scale16_16x16x128_RandomDataRandomScales)
|
||||
{
|
||||
using Case = TypeParam;
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr index_t M = Case::MPerWave;
|
||||
constexpr index_t N = Case::NPerWave;
|
||||
constexpr index_t K = Case::KPerWave;
|
||||
constexpr index_t NumScales = K / 16;
|
||||
|
||||
HostTensor<AType> A({M, K});
|
||||
HostTensor<BType> B({N, K});
|
||||
HostTensor<CType> C({M, N});
|
||||
HostTensor<e8m0_t> sA({M, NumScales});
|
||||
HostTensor<e8m0_t> sB({N, NumScales});
|
||||
|
||||
FillUniformDistribution<AType>{-5.f, 5.f, 42}(A);
|
||||
FillUniformDistribution<BType>{-5.f, 5.f, 137}(B);
|
||||
C.SetZero();
|
||||
|
||||
{
|
||||
constexpr int bias = ck_tile::numeric_traits<e8m0_t>::bias;
|
||||
std::mt19937 gen(9999);
|
||||
std::uniform_int_distribution<int> dist(bias - 4, bias + 2);
|
||||
for(auto& s : sA.mData)
|
||||
s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
|
||||
for(auto& s : sB.mData)
|
||||
s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
|
||||
}
|
||||
|
||||
RunTest<Case, false>(A, B, sA, sB, C);
|
||||
|
||||
HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
|
||||
A, B.transpose(), C_ref, sA, sB.transpose());
|
||||
|
||||
const float max_acc = *std::max_element(C_ref.mData.begin(), C_ref.mData.end());
|
||||
const auto rtol = ck_tile::get_relative_threshold<AType, CType, CType>(K);
|
||||
const auto atol = ck_tile::get_absolute_threshold<AType, CType, CType>(max_acc, K);
|
||||
EXPECT_TRUE(check_err(C, C_ref, "Scale16 random data + random scales error.", rtol, atol));
|
||||
}
|
||||
|
||||
TYPED_TEST(WGScale16Test, Scale16_16x16x128_TransposeC)
|
||||
{
|
||||
using Case = TypeParam;
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr index_t M = Case::MPerWave;
|
||||
constexpr index_t N = Case::NPerWave;
|
||||
constexpr index_t K = Case::KPerWave;
|
||||
constexpr index_t NumScales = K / 16;
|
||||
|
||||
HostTensor<AType> A({M, K});
|
||||
HostTensor<BType> B({N, K});
|
||||
HostTensor<CType> C({M, N});
|
||||
HostTensor<e8m0_t> sA({M, NumScales});
|
||||
HostTensor<e8m0_t> sB({N, NumScales});
|
||||
|
||||
FillUniformDistribution<AType>{-5.f, 5.f, 77}(A);
|
||||
FillUniformDistribution<BType>{-5.f, 5.f, 88}(B);
|
||||
C.SetZero();
|
||||
|
||||
{
|
||||
constexpr int bias = ck_tile::numeric_traits<e8m0_t>::bias;
|
||||
std::mt19937 gen(5555);
|
||||
std::uniform_int_distribution<int> dist(bias - 4, bias + 2);
|
||||
for(auto& s : sA.mData)
|
||||
s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
|
||||
for(auto& s : sB.mData)
|
||||
s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
|
||||
}
|
||||
|
||||
RunTest<Case, true>(A, B, sA, sB, C);
|
||||
|
||||
HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
|
||||
A, B.transpose(), C_ref, sA, sB.transpose());
|
||||
|
||||
const float max_acc = *std::max_element(C_ref.mData.begin(), C_ref.mData.end());
|
||||
const auto rtol = ck_tile::get_relative_threshold<AType, CType, CType>(K);
|
||||
const auto atol = ck_tile::get_absolute_threshold<AType, CType, CType>(max_acc, K);
|
||||
EXPECT_TRUE(check_err(C, C_ref, "Scale16 TransposeC error.", rtol, atol));
|
||||
}
|
||||
362
test/ck_tile/warp_gemm/test_f32_32x32x128_fp8_scale16.cpp
Normal file
362
test/ck_tile/warp_gemm/test_f32_32x32x128_fp8_scale16.cpp
Normal file
@@ -0,0 +1,362 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
//
|
||||
// Unit test for a 32x32x128 block-scaled GEMM built from the 16x16x128 scale16
|
||||
// warp gemm (V_WMMA_SCALE16_F32_16X16X128_F8F6F4), modeled after the MX GEMM
|
||||
// pipelines (BlockGemmARegBRegCRegV1 / mx_flatmm pipeline): all A/B sub-tiles and
|
||||
// their per-block K-scales for the 2x2 block-level loop are preloaded into
|
||||
// registers in advance, then consumed by the warp-gemm loop. OpSelA/OpSelB are
|
||||
// threaded through the warp-gemm call as the pipeline does; for a 16x16 tile the
|
||||
// hardware A/B sub-block index is 0, so per-M-block / per-N-block scale selection
|
||||
// is realized by indexing the preloaded per-block scales (each int64 carries that
|
||||
// block's 8 e8m0 K-scales).
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
// Distribution that stages one scale row per lane: NumRows rows mapped across the
|
||||
// NumRows lanes of the wave (lane L holds row L's NumScales e8m0 K-scales = one int64).
|
||||
// For a 32-row tile this yields lanes 0..15 -> rows 0..15, lanes 16..31 -> rows 16..31,
|
||||
// which is exactly the layout the hardware selects between via SCL_OPSEL (OpSel).
|
||||
template <index_t NumRows, index_t NumScales>
|
||||
CK_TILE_DEVICE static constexpr auto MakeScaleDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NumRows>, sequence<NumScales>>,
|
||||
tuple<sequence<1>>,
|
||||
tuple<sequence<0>>,
|
||||
sequence<2>,
|
||||
sequence<0>>{});
|
||||
}
|
||||
|
||||
// Pipeline-style kernel: preload all scales for the 2x2 block in advance, then
|
||||
// run a 2x2 block-level loop over the 16x16x128 scale16 warp gemm. Each (mIter,
|
||||
// nIter) sub-tile consumes its own independent per-M/per-N-block K-scales.
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
bool TransposeC>
|
||||
struct WarpGemmScale16BlockLoopKernel
|
||||
{
|
||||
static constexpr int kBlockSize = 32;
|
||||
static constexpr index_t ScaleBlockK = 16;
|
||||
static constexpr index_t NumScales = K / ScaleBlockK;
|
||||
static constexpr index_t MPerWarp = 16;
|
||||
static constexpr index_t NPerWarp = 16;
|
||||
static constexpr index_t MIter = M / MPerWarp;
|
||||
static constexpr index_t NIter = N / NPerWarp;
|
||||
|
||||
__device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const
|
||||
{
|
||||
const auto a_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<AType*>(A),
|
||||
make_tuple(M, K),
|
||||
make_tuple(K, number<1>{}),
|
||||
number<K>{},
|
||||
number<1>{});
|
||||
const auto b_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<BType*>(B),
|
||||
make_tuple(N, K),
|
||||
make_tuple(K, number<1>{}),
|
||||
number<K>{},
|
||||
number<1>{});
|
||||
const auto c_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<CType*>(C),
|
||||
make_tuple(M, N),
|
||||
make_tuple(N, number<1>{}),
|
||||
number<N>{},
|
||||
number<1>{});
|
||||
const auto sa_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<e8m0_t*>(ScaleA),
|
||||
make_tuple(M, NumScales),
|
||||
make_tuple(NumScales, number<1>{}),
|
||||
number<NumScales>{},
|
||||
number<1>{});
|
||||
const auto sb_view =
|
||||
make_naive_tensor_view<address_space_enum::global>(static_cast<e8m0_t*>(ScaleB),
|
||||
make_tuple(N, NumScales),
|
||||
make_tuple(NumScales, number<1>{}),
|
||||
number<NumScales>{},
|
||||
number<1>{});
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<
|
||||
AType, // ADataType: A element type (fp8_t / bf8_t)
|
||||
BType, // BDataType: B element type (fp8_t / bf8_t)
|
||||
float, // AccDataType: accumulator type (F32)
|
||||
MPerWarp, // MPerWave: warp-tile M (16 - native op size)
|
||||
NPerWarp, // NPerWave: warp-tile N (16 - native op size)
|
||||
K, // KPerWave: warp-tile K (128)
|
||||
TransposeC, // TransposeC: use transposed-C distribution
|
||||
false, // SwizzleA: A LDS swizzle layout (off)
|
||||
false, // UseStructuredSparsity: 2:4 sparsity (off)
|
||||
WGAttrNumAccessEnum::Default, // AttrNumAccessA: A num-access attribute
|
||||
WGAttrNumAccessEnum::Default, // AttrNumAccessB: B num-access attribute
|
||||
true>; // IsScale16: select the scale16 WMMA variant
|
||||
|
||||
constexpr auto a_dstr = typename WarpGemm::AWarpDstr{};
|
||||
constexpr auto b_dstr = typename WarpGemm::BWarpDstr{};
|
||||
constexpr auto c_dstr = typename WarpGemm::CWarpDstr{};
|
||||
constexpr auto scale_dstr = MakeScaleDistribution<M, NumScales>();
|
||||
|
||||
// ---- Preload phase ----
|
||||
// A/B sub-tiles for the 2x2 block loop. Scales are staged ONCE across all 32
|
||||
// lanes (not per-block arrays): a single int64 A-scale and B-scale carry every
|
||||
// row's/col's K-scales, and the per-block selection is done in hardware by
|
||||
// OpSelA/OpSelB (SCL_OPSEL) -- lanes 0..15 for block 0, lanes 16..31 for block 1.
|
||||
statically_indexed_array<decltype(load_tile(
|
||||
make_tile_window(a_view,
|
||||
make_tuple(number<MPerWarp>{}, number<K>{}),
|
||||
make_multi_index(0, 0),
|
||||
a_dstr))),
|
||||
MIter>
|
||||
a_tiles;
|
||||
static_for<0, MIter, 1>{}([&](auto mIter) {
|
||||
auto a_win = make_tile_window(a_view,
|
||||
make_tuple(number<MPerWarp>{}, number<K>{}),
|
||||
make_multi_index(mIter.value * MPerWarp, 0),
|
||||
a_dstr);
|
||||
a_tiles(mIter) = load_tile(a_win);
|
||||
});
|
||||
|
||||
statically_indexed_array<decltype(load_tile(
|
||||
make_tile_window(b_view,
|
||||
make_tuple(number<NPerWarp>{}, number<K>{}),
|
||||
make_multi_index(0, 0),
|
||||
b_dstr))),
|
||||
NIter>
|
||||
b_tiles;
|
||||
static_for<0, NIter, 1>{}([&](auto nIter) {
|
||||
auto b_win = make_tile_window(b_view,
|
||||
make_tuple(number<NPerWarp>{}, number<K>{}),
|
||||
make_multi_index(nIter.value * NPerWarp, 0),
|
||||
b_dstr);
|
||||
b_tiles(nIter) = load_tile(b_win);
|
||||
});
|
||||
|
||||
// Single 64-bit scale scalars: all M rows / N cols staged across the 32 lanes.
|
||||
auto sa_win = make_tile_window(sa_view,
|
||||
make_tuple(number<M>{}, number<NumScales>{}),
|
||||
make_multi_index(0, 0),
|
||||
scale_dstr);
|
||||
const int64_t scale_a =
|
||||
bit_cast<int64_t>(load_tile(sa_win)
|
||||
.get_thread_buffer()
|
||||
.template get_as<ext_vector_t<e8m0_t, NumScales>>()[number<0>{}]);
|
||||
auto sb_win = make_tile_window(sb_view,
|
||||
make_tuple(number<N>{}, number<NumScales>{}),
|
||||
make_multi_index(0, 0),
|
||||
scale_dstr);
|
||||
const int64_t scale_b =
|
||||
bit_cast<int64_t>(load_tile(sb_win)
|
||||
.get_thread_buffer()
|
||||
.template get_as<ext_vector_t<e8m0_t, NumScales>>()[number<0>{}]);
|
||||
|
||||
// ---- Compute phase ----
|
||||
// 2x2 block-level loop. The same int64 scale scalars are passed to every call;
|
||||
// the correct per-block K-scales are fetched in hardware via OpSelA/OpSelB, which
|
||||
// select lanes 0..15 (block 0) vs lanes 16..31 (block 1) of the scale register.
|
||||
static_for<0, MIter, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIter, 1>{}([&](auto nIter) {
|
||||
auto c_win = make_tile_window(
|
||||
c_view,
|
||||
make_tuple(number<MPerWarp>{}, number<NPerWarp>{}),
|
||||
make_multi_index(mIter.value * MPerWarp, nIter.value * NPerWarp),
|
||||
c_dstr);
|
||||
|
||||
auto c_tile =
|
||||
WarpGemm{}.template operator()<OpSelA<mIter.value>, OpSelB<nIter.value>>(
|
||||
a_tiles(mIter), b_tiles(nIter), scale_a, scale_b);
|
||||
store_tile(c_win, c_tile);
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename A, typename B, typename Acc, index_t M, index_t N, index_t K>
|
||||
struct WGDispCase
|
||||
{
|
||||
using AType = A;
|
||||
using BType = B;
|
||||
using AccType = Acc;
|
||||
static constexpr index_t MPerWave = M;
|
||||
static constexpr index_t NPerWave = N;
|
||||
static constexpr index_t KPerWave = K;
|
||||
};
|
||||
|
||||
template <typename Case, bool TransposeC>
|
||||
static void RunTest(const HostTensor<typename Case::AType>& A,
|
||||
const HostTensor<typename Case::BType>& B,
|
||||
const HostTensor<e8m0_t>& ScaleA,
|
||||
const HostTensor<e8m0_t>& ScaleB,
|
||||
HostTensor<typename Case::AccType>& C)
|
||||
{
|
||||
DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB);
|
||||
dim3 grid(1), block{32};
|
||||
|
||||
using K = WarpGemmScale16BlockLoopKernel<typename Case::AType,
|
||||
typename Case::BType,
|
||||
typename Case::AccType,
|
||||
Case::MPerWave,
|
||||
Case::NPerWave,
|
||||
Case::KPerWave,
|
||||
TransposeC>;
|
||||
|
||||
(void)launch_kernel(stream_config{nullptr, true, 0, 0, 1},
|
||||
make_kernel(K{},
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
Ad.GetDeviceBuffer(),
|
||||
Bd.GetDeviceBuffer(),
|
||||
Cd.GetDeviceBuffer(),
|
||||
SAd.GetDeviceBuffer(),
|
||||
SBd.GetDeviceBuffer()));
|
||||
|
||||
Cd.FromDevice(C.mData.data());
|
||||
}
|
||||
|
||||
using WGDispatcherTypesList = ::testing::Types<WGDispCase<fp8_t, fp8_t, float, 32, 32, 128>,
|
||||
WGDispCase<bf8_t, bf8_t, float, 32, 32, 128>,
|
||||
WGDispCase<fp8_t, bf8_t, float, 32, 32, 128>,
|
||||
WGDispCase<bf8_t, fp8_t, float, 32, 32, 128>>;
|
||||
|
||||
template <typename Case>
|
||||
class WGScale16Test : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(WGScale16Test, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGScale16Test, Scale16_32x32x128_UniformScale)
|
||||
{
|
||||
using Case = TypeParam;
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr index_t M = Case::MPerWave;
|
||||
constexpr index_t N = Case::NPerWave;
|
||||
constexpr index_t K = Case::KPerWave;
|
||||
constexpr index_t NumScales = K / 16;
|
||||
|
||||
HostTensor<AType> A({M, K});
|
||||
HostTensor<BType> B({N, K});
|
||||
HostTensor<CType> C({M, N});
|
||||
HostTensor<e8m0_t> sA({M, NumScales});
|
||||
HostTensor<e8m0_t> sB({N, NumScales});
|
||||
|
||||
FillUniformDistribution<AType>{-5.f, 5.f}(A);
|
||||
FillUniformDistribution<BType>{-5.f, 5.f}(B);
|
||||
C.SetZero();
|
||||
FillConstant<e8m0_t>{e8m0_t{2.f}}(sA);
|
||||
FillConstant<e8m0_t>{e8m0_t{4.f}}(sB);
|
||||
|
||||
RunTest<Case, false>(A, B, sA, sB, C);
|
||||
|
||||
HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
|
||||
A, B.transpose(), C_ref, sA, sB.transpose());
|
||||
|
||||
EXPECT_TRUE(check_err(C, C_ref, "Scale16 32x32x128 uniform scale error."));
|
||||
}
|
||||
|
||||
TYPED_TEST(WGScale16Test, Scale16_32x32x128_RandomDataRandomScales)
|
||||
{
|
||||
using Case = TypeParam;
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr index_t M = Case::MPerWave;
|
||||
constexpr index_t N = Case::NPerWave;
|
||||
constexpr index_t K = Case::KPerWave;
|
||||
constexpr index_t NumScales = K / 16;
|
||||
|
||||
HostTensor<AType> A({M, K});
|
||||
HostTensor<BType> B({N, K});
|
||||
HostTensor<CType> C({M, N});
|
||||
HostTensor<e8m0_t> sA({M, NumScales});
|
||||
HostTensor<e8m0_t> sB({N, NumScales});
|
||||
|
||||
FillUniformDistribution<AType>{-5.f, 5.f, 42}(A);
|
||||
FillUniformDistribution<BType>{-5.f, 5.f, 137}(B);
|
||||
C.SetZero();
|
||||
|
||||
{
|
||||
constexpr int bias = ck_tile::numeric_traits<e8m0_t>::bias;
|
||||
std::mt19937 gen(9999);
|
||||
std::uniform_int_distribution<int> dist(bias - 4, bias + 2);
|
||||
for(auto& s : sA.mData)
|
||||
s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
|
||||
for(auto& s : sB.mData)
|
||||
s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
|
||||
}
|
||||
|
||||
RunTest<Case, false>(A, B, sA, sB, C);
|
||||
|
||||
HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
|
||||
A, B.transpose(), C_ref, sA, sB.transpose());
|
||||
|
||||
const float max_acc = *std::max_element(C_ref.mData.begin(), C_ref.mData.end());
|
||||
const auto rtol = ck_tile::get_relative_threshold<AType, CType, CType>(K);
|
||||
const auto atol = ck_tile::get_absolute_threshold<AType, CType, CType>(max_acc, K);
|
||||
EXPECT_TRUE(
|
||||
check_err(C, C_ref, "Scale16 32x32x128 random data + random scales error.", rtol, atol));
|
||||
}
|
||||
|
||||
TYPED_TEST(WGScale16Test, Scale16_32x32x128_TransposeC)
|
||||
{
|
||||
using Case = TypeParam;
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr index_t M = Case::MPerWave;
|
||||
constexpr index_t N = Case::NPerWave;
|
||||
constexpr index_t K = Case::KPerWave;
|
||||
constexpr index_t NumScales = K / 16;
|
||||
|
||||
HostTensor<AType> A({M, K});
|
||||
HostTensor<BType> B({N, K});
|
||||
HostTensor<CType> C({M, N});
|
||||
HostTensor<e8m0_t> sA({M, NumScales});
|
||||
HostTensor<e8m0_t> sB({N, NumScales});
|
||||
|
||||
FillUniformDistribution<AType>{-5.f, 5.f, 77}(A);
|
||||
FillUniformDistribution<BType>{-5.f, 5.f, 88}(B);
|
||||
C.SetZero();
|
||||
|
||||
{
|
||||
constexpr int bias = ck_tile::numeric_traits<e8m0_t>::bias;
|
||||
std::mt19937 gen(5555);
|
||||
std::uniform_int_distribution<int> dist(bias - 4, bias + 2);
|
||||
for(auto& s : sA.mData)
|
||||
s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
|
||||
for(auto& s : sB.mData)
|
||||
s = e8m0_t(static_cast<typename e8m0_t::type>(dist(gen)));
|
||||
}
|
||||
|
||||
RunTest<Case, true>(A, B, sA, sB, C);
|
||||
|
||||
HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
|
||||
A, B.transpose(), C_ref, sA, sB.transpose());
|
||||
|
||||
const float max_acc = *std::max_element(C_ref.mData.begin(), C_ref.mData.end());
|
||||
const auto rtol = ck_tile::get_relative_threshold<AType, CType, CType>(K);
|
||||
const auto atol = ck_tile::get_absolute_threshold<AType, CType, CType>(max_acc, K);
|
||||
EXPECT_TRUE(check_err(C, C_ref, "Scale16 32x32x128 TransposeC error.", rtol, atol));
|
||||
}
|
||||
Reference in New Issue
Block a user