[rocm-libraries] ROCm/rocm-libraries#7677 (commit 308af93)

[CK_Tile] Add scale16 Support for F4 WMMA in CK_Tile

## Motivation
This PR adds CK Tile support for the scale16 F4 WMMA path on gfx1250 and
improves warp GEMM unit test coverage/structure for gfx1250-specific
cases.

## Technical Details

- Scale16 support in warp GEMM dispatch and WMMA trait plumbing: added
IsScale16 plumbing to warp GEMM dispatcher path
- Warp GEMM test restructuring for gfx1250: added Warp GEMM gfx1250
coverage to verify all F4 WMMA paths

## Test Plan
Run ./test_ck_tile_wg_32x16x128_fp4.

## Test Result
```
./test_ck_tile_wg_32x16x128_fp4
[----------] Global test environment tear-down
[==========] 3 tests from 1 test suite ran. (1751 ms total)
[  PASSED  ] 3 tests.
```

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Tianyuan Wu
2026-05-30 01:28:48 +00:00
committed by assistant-librarian[bot]
parent 8d97265896
commit 22a99f97e8
11 changed files with 507 additions and 234 deletions

View File

@@ -103,7 +103,101 @@ struct Packed4Scale
}
};
template <typename ScaleType>
struct Packed8Scale
{
using scale_type = ScaleType;
using raw_type = uint64_t;
using raw_scale_type = typename ScaleType::raw_type;
static constexpr int num_pack = 8;
union
{
raw_type data_;
raw_scale_type scales_[num_pack]; // Direct byte/element access
};
// Constructors
CK_TILE_HOST_DEVICE constexpr Packed8Scale() = default;
CK_TILE_HOST_DEVICE constexpr Packed8Scale(raw_type val) : data_(val) {}
CK_TILE_HOST_DEVICE constexpr Packed8Scale(
float s0, float s1, float s2, float s3, float s4, float s5, float s6, float s7)
{
set_scales_from_float(s0, s1, s2, s3, s4, s5, s6, s7);
}
CK_TILE_HOST_DEVICE constexpr Packed8Scale(ScaleType s0,
ScaleType s1,
ScaleType s2,
ScaleType s3,
ScaleType s4,
ScaleType s5,
ScaleType s6,
ScaleType s7)
{
set_scales(s0, s1, s2, s3, s4, s5, s6, s7);
}
CK_TILE_HOST_DEVICE constexpr void set_scales_from_float(
float s0, float s1, float s2, float s3, float s4, float s5, float s6, float s7)
{
set_scales(ScaleType(s0),
ScaleType(s1),
ScaleType(s2),
ScaleType(s3),
ScaleType(s4),
ScaleType(s5),
ScaleType(s6),
ScaleType(s7));
}
CK_TILE_HOST_DEVICE constexpr void set_scales(ScaleType s0,
ScaleType s1,
ScaleType s2,
ScaleType s3,
ScaleType s4,
ScaleType s5,
ScaleType s6,
ScaleType s7)
{
data_ = 0;
pack_scale(s0, 7);
pack_scale(s1, 6);
pack_scale(s2, 5);
pack_scale(s3, 4);
pack_scale(s4, 3);
pack_scale(s5, 2);
pack_scale(s6, 1);
pack_scale(s7, 0);
}
CK_TILE_HOST_DEVICE constexpr operator raw_type() const { return data_; }
CK_TILE_HOST_DEVICE constexpr raw_type& data() [[clang::lifetimebound]] { return data_; }
CK_TILE_HOST_DEVICE constexpr raw_type data() const { return data_; }
CK_TILE_HOST_DEVICE constexpr float unpack_to_float(int i) const
{
return static_cast<float>(unpack_scale(i));
}
CK_TILE_HOST_DEVICE constexpr ScaleType unpack_scale(int i) const
{
return ScaleType(scales_[i]);
}
CK_TILE_HOST_DEVICE constexpr void pack_from_float(float scale, int i)
{
pack_scale(ScaleType(scale), i);
}
CK_TILE_HOST_DEVICE constexpr void pack_scale(ScaleType scale, int i)
{
scales_[i] = scale.get();
}
};
// Type alias for e8m0_t scales
using Packed4Scale_E8M0 = Packed4Scale<e8m0_t>;
using Packed8Scale_E8M0 = Packed8Scale<e8m0_t>;
} // namespace ck_tile

View File

@@ -234,12 +234,12 @@ struct WarpGemmAttributeWmma
}
// c_vec += a_vec * b_vec
template <typename... Params>
template <typename... Params, typename AScaleType, typename BScaleType>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const int32_t& a_scale,
const AScaleType& a_scale,
const BVecType& b_vec,
const int32_t& b_scale) const
const BScaleType& b_scale) const
{
if constexpr(kTransC)
{
@@ -253,11 +253,11 @@ struct WarpGemmAttributeWmma
}
// c_vec = a_vec * b_vec
template <typename... Params>
template <typename... Params, typename AScaleType, typename BScaleType>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const int32_t& a_scale,
const AScaleType& a_scale,
const BVecType& b_vec,
const int32_t& b_scale) const
const BScaleType& b_scale) const
{
if constexpr(kTransC)
{

View File

@@ -19,6 +19,11 @@ template <typename Arch,
typename MXTypeEnable = void>
struct WmmaTraits;
// Tag used to select scale16 WMMA traits specializations.
struct WmmaScale16Tag
{
};
// Generic WMMA implementation using traits
template <typename Traits>
struct WarpGemmAttributeWmmaImpl
@@ -88,22 +93,22 @@ struct WarpGemmAttributeWmmaImpl
Traits::template wmma_intrinsic<Params...>(a_vec, b_vec, CVecType{0.f}));
}
template <typename... Params>
template <typename... Params, typename AScaleType, typename BScaleType>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const int32_t& a_scale,
const AScaleType& a_scale,
const BVecType& b_vec,
const int32_t& b_scale) const
const BScaleType& b_scale) const
{
c_vec = Traits::template wmma_intrinsic<Params...>(a_vec, a_scale, b_vec, b_scale, c_vec);
}
// c_vec = a_vec * b_vec
template <typename... Params>
template <typename... Params, typename AScaleType, typename BScaleType>
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
const int32_t& a_scale,
const AScaleType& a_scale,
const BVecType& b_vec,
const int32_t& b_scale) const
const BScaleType& b_scale) const
{
return bit_cast<CVecType>(Traits::template wmma_intrinsic<Params...>(
a_vec, a_scale, b_vec, b_scale, CVecType{0.f}));
@@ -177,6 +182,9 @@ using WarpGemmAttributeWmmaImpl_f32_32x16x128_f4 =
using WarpGemmAttributeWmmaImpl_f32_32x32x128_f4 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>>;
using WarpGemmAttributeWmmaImpl_f32_32x32x128_f4_scale16 = WarpGemmAttributeWmmaImpl<
WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128, WmmaScale16Tag>>;
using WarpGemmAttributeWmmaImpl_f16_16x16x64_f8_f8 =
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, fp8_t, fp8_t, fp16_t, 16, 16, 64>>;

View File

@@ -6,6 +6,9 @@
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
#include "warp_gemm_params.hpp"
namespace ck_tile {
struct WmmaScale16Tag;
// int8 specialization - GFX11
template <>
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
@@ -528,17 +531,18 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 16, 128>
}
};
template <>
struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
template <bool IsScale16>
struct WmmaTraitsGfx125PkFp4F32_32x32x128
: WmmaTraitsBase<gfx12_t, pk_fp4_t, pk_fp4_t, float, 128, false, 32, 32>
{
using ArchType = gfx125_t;
using ArchType = gfx125_t;
using ScaleType = std::conditional_t<IsScale16, int64_t, int32_t>;
template <typename... Params>
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
const int32_t& a_scale,
const ScaleType& a_scale,
const BVecType& b_vec,
const int32_t& b_scale,
const ScaleType& b_scale,
const CVecType& c_vec)
{
#ifdef __gfx125__
@@ -569,19 +573,38 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
const auto& b_slice = b_buffer.template get_as<BSliceType>()[n];
auto& c_slice = c_result.template get_as<CSliceType>()[n];
c_slice = __builtin_amdgcn_wmma_scale_f32_32x16x128_f4(
bit_cast<int32x16_t>(a_slice),
bit_cast<int32x8_t>(b_slice),
0,
c_slice,
1, // OPSEL[0] - fixed to 1 for F4
P::scale_a, // OPSEL_HI[0] - scale data type for A
a_scale,
n.value, // OPSEL[1] - select B scale (iterates over N blocks)
P::scale_b, // OPSEL_HI[1] - scale data type for B
b_scale,
0, // NEG
0); // NEG_HI
if constexpr(IsScale16)
{
c_slice = __builtin_amdgcn_wmma_scale16_f32_32x16x128_f4(
bit_cast<int32x16_t>(a_slice),
bit_cast<int32x8_t>(b_slice),
0,
c_slice,
1, // OPSEL[0] - fixed to 1 for F4
P::scale_a, // OPSEL_HI[0] - scale data type for A
a_scale,
n.value, // OPSEL[1] - select B scale (iterates over N blocks)
P::scale_b, // OPSEL_HI[1] - scale data type for B
b_scale,
0, // NEG
0); // NEG_HI
}
else
{
c_slice = __builtin_amdgcn_wmma_scale_f32_32x16x128_f4(
bit_cast<int32x16_t>(a_slice),
bit_cast<int32x8_t>(b_slice),
0,
c_slice,
1, // OPSEL[0] - fixed to 1 for F4
P::scale_a, // OPSEL_HI[0] - scale data type for A
a_scale,
n.value, // OPSEL[1] - select B scale (iterates over N blocks)
P::scale_b, // OPSEL_HI[1] - scale data type for B
b_scale,
0, // NEG
0); // NEG_HI
}
});
return bit_cast<CVecType>(c_result);
@@ -602,7 +625,8 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
#ifdef __gfx125__
// Pass default scale values 1.0f
Packed4Scale_E8M0 pkscale(1.0f, 1.0f, 1.0f, 1.0f);
return wmma_intrinsic(a_vec, pkscale, b_vec, pkscale, c_vec);
const auto default_scale = static_cast<ScaleType>(pkscale);
return wmma_intrinsic(a_vec, default_scale, b_vec, default_scale, c_vec);
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
@@ -612,6 +636,18 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
}
};
template <>
struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
: WmmaTraitsGfx125PkFp4F32_32x32x128<false>
{
};
template <>
struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128, WmmaScale16Tag>
: WmmaTraitsGfx125PkFp4F32_32x32x128<true>
{
};
// f8f6f4 specialization - GFX125
enum F8F6F4OpDataTypeEnum
{

View File

@@ -33,6 +33,7 @@ template <typename AType,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccessA = ESingle,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA,
bool IsScale16 = false,
typename Enable = void>
struct Dispatcher;
@@ -178,10 +179,10 @@ template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Ty
#if !defined(__gfx125__)
// scale mfma based f8f6f4
template<typename A, typename B, WGAttrNumAccessEnum I>
struct Dispatcher<A, B, float, 16, 16, 128, false, false, false, I, I, std::enable_if_t<I != EDefault>> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4<A, B, I>; };
template<typename A, typename B, WGAttrNumAccessEnum I>
struct Dispatcher<A, B, float, 16, 16, 128, true, false, false, I, I, std::enable_if_t<I != EDefault>> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4_CTransposed<A, B, I>; };
template<typename A, typename B, WGAttrNumAccessEnum I, bool IsScale16>
struct Dispatcher<A, B, float, 16, 16, 128, false, false, false, I, I, IsScale16, std::enable_if_t<I != EDefault>> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4<A, B, I>; };
template<typename A, typename B, WGAttrNumAccessEnum I, bool IsScale16>
struct Dispatcher<A, B, float, 16, 16, 128, true, false, false, I, I, IsScale16, std::enable_if_t<I != EDefault>> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4_CTransposed<A, B, I>; };
#endif
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
@@ -224,7 +225,7 @@ template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<f
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 64, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type = WarpGemmWmma_f32_16x16x64_bf8_f8<TransposeC, AttrNumAccess>; };
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 16, 128, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type = WarpGemmWmma_f32_32x16x128_f4<TransposeC, AttrNumAccess>; };
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 32, 128, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f4<TransposeC, AttrNumAccess>; };
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess, bool IsScale16> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 32, 128, TransposeC, false, false, AttrNumAccess, AttrNumAccess, IsScale16> : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f4<TransposeC, AttrNumAccess, IsScale16>; };
#if defined(__gfx125__)
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type = WarpGemmWmma_f32_16x16x64_f8_f8<TransposeC, AttrNumAccess>; };
@@ -244,8 +245,27 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Typ
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; };
#endif
template<typename A, typename B, bool TransposeC, WGAttrNumAccessEnum AttrNumAccessA, WGAttrNumAccessEnum AttrNumAccessB>
struct Dispatcher<A, B, float, 32, 32, 128, TransposeC, false, false, AttrNumAccessA, AttrNumAccessB> : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f8f6f4<A, B, TransposeC, AttrNumAccessA, AttrNumAccessB>; };
template<typename A,
typename B,
bool TransposeC,
WGAttrNumAccessEnum AttrNumAccessA,
WGAttrNumAccessEnum AttrNumAccessB,
bool IsScale16>
struct Dispatcher<A,
B,
float,
32,
32,
128,
TransposeC,
false,
false,
AttrNumAccessA,
AttrNumAccessB,
IsScale16> : WmmaTag
{
using Type = WarpGemmWmma_f32_32x32x128_f8f6f4<A, B, TransposeC, AttrNumAccessA, AttrNumAccessB>;
};
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<fp8_t, fp8_t, half_t, 16, 16, 64, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type =WarpGemmWmma_f16_16x16x64_f8_f8<TransposeC, AttrNumAccess>; };
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<bf8_t, bf8_t, half_t, 16, 16, 64, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type =WarpGemmWmma_f16_16x16x64_bf8_bf8<TransposeC, AttrNumAccess>; };
@@ -265,12 +285,12 @@ template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<u
template <typename AType, typename BType, typename AccType,
index_t M, index_t N, index_t K,
bool TransposeC, bool SA, bool SS>
bool TransposeC, bool SA, bool SS, bool IsScale16>
struct Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS,
EDefault, EDefault,
EDefault, EDefault, IsScale16,
std::enable_if_t<!std::is_base_of_v<WmmaTag,
Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS, ESingle, ESingle, void>>>>
: Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS, ESingle, ESingle, void> {};
Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS, ESingle, ESingle, IsScale16, void>>>>
: Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS, ESingle, ESingle, IsScale16, void> {};
// clang-format on
} // namespace warp_gemm_dispatcher
@@ -286,7 +306,8 @@ template <typename AType,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Default,
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA,
bool IsScale16 = false>
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
AType,
BType,
@@ -298,6 +319,7 @@ using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
SwizzleA,
UseStructuredSparsity,
AttrNumAccessA,
AttrNumAccessB>::Type;
AttrNumAccessB,
IsScale16>::Type;
} // namespace ck_tile

View File

@@ -90,12 +90,17 @@ struct WarpGemmImpl
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
}
template <typename... Params, typename CTensor, typename ATensor, typename BTensor>
template <typename... Params,
typename CTensor,
typename ATensor,
typename BTensor,
typename AScaleType,
typename BScaleType>
CK_TILE_DEVICE void operator()(CTensor& c,
const ATensor& a,
const BTensor& b,
const int32_t& a_scale,
const int32_t& b_scale) const
const AScaleType& a_scale,
const BScaleType& b_scale) const
{
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
@@ -141,11 +146,15 @@ struct WarpGemmImpl
return c;
}
template <typename... Params, typename ATensor, typename BTensor>
template <typename... Params,
typename ATensor,
typename BTensor,
typename AScaleType,
typename BScaleType>
CK_TILE_DEVICE auto operator()(const ATensor& a,
const BTensor& b,
const int32_t& a_scale,
const int32_t& b_scale) const
const AScaleType& a_scale,
const BScaleType& b_scale) const
{
using CTensor = CWarpTensor;
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&

View File

@@ -187,12 +187,16 @@ using WarpGemmWmma_f32_32x16x128_f4 =
AttrNumAccess,
AttrNumAccess>>;
template <bool kTransC = false, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Default>
using WarpGemmWmma_f32_32x32x128_f4 =
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_32x32x128_f4,
kTransC,
AttrNumAccess,
AttrNumAccess>>;
template <bool kTransC = false,
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Default,
bool IsScale16 = false>
using WarpGemmWmma_f32_32x32x128_f4 = WarpGemmImpl<
WarpGemmAttributeWmma<std::conditional_t<IsScale16,
WarpGemmAttributeWmmaImpl_f32_32x32x128_f4_scale16,
WarpGemmAttributeWmmaImpl_f32_32x32x128_f4>,
kTransC,
AttrNumAccess,
AttrNumAccess>>;
template <typename AType,
typename BType,

View File

@@ -4,3 +4,7 @@
if(GPU_TARGETS MATCHES "gfx95")
add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp)
endif()
if(GPU_TARGETS MATCHES "gfx125")
add_gtest_executable(test_ck_tile_wg_32x16x128_fp4 test_f32_32x16x128_fp4.cpp)
endif()

View File

@@ -1,153 +1,19 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
using namespace ck_tile;
template <typename A,
typename B,
typename Acc,
index_t M,
index_t N,
index_t K,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum NA = WGAttrNumAccessEnum::Single>
struct WGDispCase
{
using AType = A;
using BType = B;
using AccType = Acc;
static constexpr index_t MPerWave = M;
static constexpr index_t NPerWave = N;
static constexpr index_t KPerWave = K;
static constexpr bool kTransposeC = TransposeC;
static constexpr bool kSwizzleA = SwizzleA;
static constexpr bool kUSS = UseStructuredSparsity;
static constexpr WGAttrNumAccessEnum kNA = NA;
};
#include "test_gemm_util.hpp"
#include "gtest/gtest.h"
using WGDispatcherTypesList =
::testing::Types<WGDispCase<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, float, 16, 16, 128, false>>;
::testing::Types<ck_tile::test::warp_gemm::WGDispCase<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
float,
false,
false,
false,
ck_tile::WGAttrNumAccessEnum::Single>>;
template <typename AType,
typename BType,
typename CType,
index_t M,
index_t N,
index_t K,
bool TransposeC,
bool SwizzleA,
bool UseStructuredSparsity,
WGAttrNumAccessEnum NumAccess>
struct WarpGemmKernel
{
static constexpr int kBlockSize = 64;
__device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const
{
using WarpGemm = ck_tile::WarpGemmDispatcher<AType,
BType,
CType,
M,
N,
K,
TransposeC,
SwizzleA,
UseStructuredSparsity,
NumAccess>;
// A: [M,K] row-major (packed)
const auto a_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
static_cast<AType*>(A),
ck_tile::make_tuple(M, K),
ck_tile::make_tuple(K, ck_tile::number<1>{}),
ck_tile::number<K>{},
ck_tile::number<1>{});
// B: expose as logical [N,K] with strides (1, N) over the original row-major [K,N] buffer
const auto b_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
static_cast<BType*>(B),
ck_tile::make_tuple(N, K),
ck_tile::make_tuple(K, ck_tile::number<1>{}),
ck_tile::number<K>{},
ck_tile::number<1>{});
// C: [M,N] row-major (packed)
const auto c_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
static_cast<CType*>(C),
ck_tile::make_tuple(M, N),
ck_tile::make_tuple(N, ck_tile::number<1>{}),
ck_tile::number<N>{},
ck_tile::number<1>{});
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths();
constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths();
constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths();
auto a_win = ck_tile::make_tile_window(
a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution());
auto b_win = ck_tile::make_tile_window(
b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution());
auto c_win = ck_tile::make_tile_window(
c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution());
AWarpTensor a_tile;
BWarpTensor b_tile;
ck_tile::load_tile(a_tile, a_win);
ck_tile::load_tile(b_tile, b_win);
auto scale_a = static_cast<int32_t>(static_cast<ck_tile::e8m0_t*>(ScaleA)[0].get());
auto scale_b = static_cast<int32_t>(static_cast<ck_tile::e8m0_t*>(ScaleB)[0].get());
auto c_tile =
WarpGemm{}.template operator()<OpSelA<0>, OpSelB<0>>(a_tile, b_tile, scale_a, scale_b);
ck_tile::store_tile(c_win, c_tile);
}
};
template <typename Case>
static void RunWarpGemmCase(const ck_tile::HostTensor<typename Case::AType>& A,
const ck_tile::HostTensor<typename Case::BType>& B,
const ck_tile::HostTensor<e8m0_t>& ScaleA,
const ck_tile::HostTensor<e8m0_t>& ScaleB,
ck_tile::HostTensor<typename Case::AccType>& C)
{
ck_tile::DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB);
dim3 grid(1), block{64};
using Kernel = WarpGemmKernel<typename Case::AType,
typename Case::BType,
typename Case::AccType,
Case::MPerWave,
Case::NPerWave,
Case::KPerWave,
Case::kTransposeC,
Case::kSwizzleA,
Case::kUSS,
Case::kNA>;
(void)ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 0, 1},
ck_tile::make_kernel(Kernel{},
grid,
block,
0,
Ad.GetDeviceBuffer(),
Bd.GetDeviceBuffer(),
Cd.GetDeviceBuffer(),
SAd.GetDeviceBuffer(),
SBd.GetDeviceBuffer()));
Cd.FromDevice(C.mData.data());
}
template <typename Case>
template <typename T>
class WGRuntimeTest : public ::testing::Test
{
};
@@ -156,38 +22,6 @@ TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG)
{
using Case = TypeParam;
using AType = typename Case::AType;
using BType = typename Case::BType;
using CType = typename Case::AccType;
using ck_tile::e8m0_t;
constexpr index_t M = Case::MPerWave;
constexpr index_t N = Case::NPerWave;
constexpr index_t K = Case::KPerWave;
auto ScaleA = e8m0_t{2.f};
auto ScaleB = e8m0_t{4.f};
ck_tile::HostTensor<AType> A({M, K});
ck_tile::HostTensor<BType> B({N, K});
ck_tile::HostTensor<CType> C({M, N});
ck_tile::HostTensor<e8m0_t> sA({M, 1});
ck_tile::HostTensor<e8m0_t> sB({N, 1});
ck_tile::FillUniformDistribution<AType>{-5.f, 5.f}(A);
ck_tile::FillUniformDistribution<BType>{-5.f, 5.f}(B);
C.SetZero();
ck_tile::FillConstant<e8m0_t>{ScaleA}(sA);
ck_tile::FillConstant<e8m0_t>{ScaleB}(sB);
RunWarpGemmCase<Case>(A, B, sA, sB, C);
ck_tile::HostTensor<CType> C_ref({M, N});
C_ref.SetZero();
ck_tile::reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
A, B.transpose(), C_ref, sA, sB.transpose());
EXPECT_TRUE(ck_tile::check_err(C, C_ref, "Warp gemm result error."));
ck_tile::test::warp_gemm::
RunCompareDispatcherAndReference<TypeParam, 16, 16, 128, true, false>();
}

View File

@@ -0,0 +1,39 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "test_gemm_util.hpp"
#include "gtest/gtest.h"
using WGDispatcherTypesList =
::testing::Types<ck_tile::test::warp_gemm::WGDispCase<ck_tile::pk_fp4_t,
ck_tile::pk_fp4_t,
float,
false,
false,
false,
ck_tile::WGAttrNumAccessEnum::Single>>;
template <typename T>
class WGRuntimeTest : public ::testing::Test
{
};
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_NonScaled)
{
ck_tile::test::warp_gemm::
RunCompareDispatcherAndReference<TypeParam, 32, 16, 128, false, false>();
}
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_Scale16)
{
ck_tile::test::warp_gemm::
RunCompareDispatcherAndReference<TypeParam, 32, 32, 128, true, true>();
}
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_Scale32)
{
ck_tile::test::warp_gemm::
RunCompareDispatcherAndReference<TypeParam, 32, 32, 128, true, false>();
}

View File

@@ -0,0 +1,223 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/core/numeric/mxfp_scale.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile::test::warp_gemm {
template <typename A,
typename B,
typename Acc,
bool TransposeC,
bool SwizzleA = false,
bool UseStructuredSparsity = false,
WGAttrNumAccessEnum NA = WGAttrNumAccessEnum::Single>
struct WGDispCase
{
using AType = A;
using BType = B;
using AccType = Acc;
static constexpr bool kTransposeC = TransposeC;
static constexpr bool kSwizzleA = SwizzleA;
static constexpr bool kUSS = UseStructuredSparsity;
static constexpr WGAttrNumAccessEnum kNA = NA;
};
template <typename Case,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool UseScale = true,
bool IsScale16 = false>
struct WarpGemmKernel
{
static constexpr int kBlockSize = 64;
__device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const
{
using WarpGemm = ck_tile::WarpGemmDispatcher<typename Case::AType,
typename Case::BType,
typename Case::AccType,
MPerWave,
NPerWave,
KPerWave,
Case::kTransposeC,
Case::kSwizzleA,
Case::kUSS,
Case::kNA,
Case::kNA,
IsScale16>;
const auto a_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
static_cast<typename Case::AType*>(A),
ck_tile::make_tuple(MPerWave, KPerWave),
ck_tile::make_tuple(KPerWave, ck_tile::number<1>{}),
ck_tile::number<KPerWave>{},
ck_tile::number<1>{});
const auto b_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
static_cast<typename Case::BType*>(B),
ck_tile::make_tuple(NPerWave, KPerWave),
ck_tile::make_tuple(KPerWave, ck_tile::number<1>{}),
ck_tile::number<KPerWave>{},
ck_tile::number<1>{});
const auto c_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
static_cast<typename Case::AccType*>(C),
ck_tile::make_tuple(MPerWave, NPerWave),
ck_tile::make_tuple(NPerWave, ck_tile::number<1>{}),
ck_tile::number<NPerWave>{},
ck_tile::number<1>{});
using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths();
constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths();
constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths();
auto a_win = ck_tile::make_tile_window(
a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution());
auto b_win = ck_tile::make_tile_window(
b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution());
auto c_win = ck_tile::make_tile_window(
c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution());
AWarpTensor a_tile;
BWarpTensor b_tile;
ck_tile::load_tile(a_tile, a_win);
ck_tile::load_tile(b_tile, b_win);
const auto c_tile = [&]() {
if constexpr(UseScale)
{
using ScaleType = std::conditional_t<IsScale16, int64_t, int32_t>;
const auto scale_a = static_cast<ck_tile::e8m0_t*>(ScaleA)[0];
const auto scale_b = static_cast<ck_tile::e8m0_t*>(ScaleB)[0];
const auto packed_scale_a = [&]() -> ScaleType {
if constexpr(IsScale16)
{
Packed8Scale_E8M0 pkscale(
scale_a, scale_a, scale_a, scale_a, scale_a, scale_a, scale_a, scale_a);
return static_cast<ScaleType>(pkscale);
}
else
{
Packed4Scale_E8M0 pkscale(scale_a, scale_a, scale_a, scale_a);
return static_cast<ScaleType>(pkscale);
}
}();
const auto packed_scale_b = [&]() -> ScaleType {
if constexpr(IsScale16)
{
Packed8Scale_E8M0 pkscale(
scale_b, scale_b, scale_b, scale_b, scale_b, scale_b, scale_b, scale_b);
return static_cast<ScaleType>(pkscale);
}
else
{
Packed4Scale_E8M0 pkscale(scale_b, scale_b, scale_b, scale_b);
return static_cast<ScaleType>(pkscale);
}
}();
return WarpGemm{}.template operator()<OpSelA<0>, OpSelB<0>>(
a_tile, b_tile, packed_scale_a, packed_scale_b);
}
else
{
ck_tile::ignore = ScaleA;
ck_tile::ignore = ScaleB;
return WarpGemm{}.template operator()<OpSelA<0>, OpSelB<0>>(a_tile, b_tile);
}
}();
ck_tile::store_tile(c_win, c_tile);
}
};
template <typename Case,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool UseScale = true,
bool IsScale16 = false>
void RunWarpGemmCase(const ck_tile::HostTensor<typename Case::AType>& A,
const ck_tile::HostTensor<typename Case::BType>& B,
const ck_tile::HostTensor<ck_tile::e8m0_t>& ScaleA,
const ck_tile::HostTensor<ck_tile::e8m0_t>& ScaleB,
ck_tile::HostTensor<typename Case::AccType>& C)
{
ck_tile::DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB);
dim3 grid(1), block{64};
(void)ck_tile::launch_kernel(
ck_tile::stream_config{nullptr, true, 0, 0, 1},
ck_tile::make_kernel(
WarpGemmKernel<Case, MPerWave, NPerWave, KPerWave, UseScale, IsScale16>{},
grid,
block,
0,
Ad.GetDeviceBuffer(),
Bd.GetDeviceBuffer(),
Cd.GetDeviceBuffer(),
SAd.GetDeviceBuffer(),
SBd.GetDeviceBuffer()));
Cd.FromDevice(C.mData.data());
}
template <typename Case,
index_t MPerWave,
index_t NPerWave,
index_t KPerWave,
bool UseScale = true,
bool IsScale16 = false>
void RunCompareDispatcherAndReference()
{
using AType = typename Case::AType;
using BType = typename Case::BType;
using CType = typename Case::AccType;
constexpr index_t M = MPerWave;
constexpr index_t N = NPerWave;
constexpr index_t K = KPerWave;
const auto ScaleA = ck_tile::e8m0_t{2.f};
const auto ScaleB = ck_tile::e8m0_t{4.f};
ck_tile::HostTensor<AType> A({M, K});
ck_tile::HostTensor<BType> B({N, K});
ck_tile::HostTensor<CType> C({M, N});
ck_tile::HostTensor<ck_tile::e8m0_t> sA({M, 1});
ck_tile::HostTensor<ck_tile::e8m0_t> sB({N, 1});
ck_tile::FillUniformDistribution<AType>{-5.f, 5.f}(A);
ck_tile::FillUniformDistribution<BType>{-5.f, 5.f}(B);
C.SetZero();
ck_tile::FillConstant<ck_tile::e8m0_t>{ScaleA}(sA);
ck_tile::FillConstant<ck_tile::e8m0_t>{ScaleB}(sB);
RunWarpGemmCase<Case, MPerWave, NPerWave, KPerWave, UseScale, IsScale16>(A, B, sA, sB, C);
ck_tile::HostTensor<CType> C_ref({M, N});
C_ref.SetZero();
if constexpr(UseScale)
{
ck_tile::reference_mx_gemm<AType, BType, ck_tile::e8m0_t, ck_tile::e8m0_t, CType, CType>(
A, B.transpose(), C_ref, sA, sB.transpose());
}
else
{
ck_tile::reference_gemm<AType, BType, CType, CType>(A, B.transpose(), C_ref);
}
EXPECT_TRUE(ck_tile::check_err(C, C_ref, "Warp gemm result error."));
}
} // namespace ck_tile::test::warp_gemm