mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-11 08:48:45 +00:00
[rocm-libraries] ROCm/rocm-libraries#7677 (commit 308af93)
[CK_Tile] Add scale16 Support for F4 WMMA in CK_Tile ## Motivation This PR adds CK Tile support for the scale16 F4 WMMA path on gfx1250 and improves warp GEMM unit test coverage/structure for gfx1250-specific cases. ## Technical Details - Scale16 support in warp GEMM dispatch and WMMA trait plumbing: added IsScale16 plumbing to warp GEMM dispatcher path - Warp GEMM test restructuring for gfx1250: added Warp GEMM gfx1250 coverage to verify all F4 WMMA paths ## Test Plan Run ./test_ck_tile_wg_32x16x128_fp4. ## Test Result ``` ./test_ck_tile_wg_32x16x128_fp4 [----------] Global test environment tear-down [==========] 3 tests from 1 test suite ran. (1751 ms total) [ PASSED ] 3 tests. ``` ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
8d97265896
commit
22a99f97e8
@@ -103,7 +103,101 @@ struct Packed4Scale
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ScaleType>
|
||||
struct Packed8Scale
|
||||
{
|
||||
using scale_type = ScaleType;
|
||||
using raw_type = uint64_t;
|
||||
using raw_scale_type = typename ScaleType::raw_type;
|
||||
|
||||
static constexpr int num_pack = 8;
|
||||
union
|
||||
{
|
||||
raw_type data_;
|
||||
raw_scale_type scales_[num_pack]; // Direct byte/element access
|
||||
};
|
||||
|
||||
// Constructors
|
||||
CK_TILE_HOST_DEVICE constexpr Packed8Scale() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr Packed8Scale(raw_type val) : data_(val) {}
|
||||
CK_TILE_HOST_DEVICE constexpr Packed8Scale(
|
||||
float s0, float s1, float s2, float s3, float s4, float s5, float s6, float s7)
|
||||
{
|
||||
set_scales_from_float(s0, s1, s2, s3, s4, s5, s6, s7);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr Packed8Scale(ScaleType s0,
|
||||
ScaleType s1,
|
||||
ScaleType s2,
|
||||
ScaleType s3,
|
||||
ScaleType s4,
|
||||
ScaleType s5,
|
||||
ScaleType s6,
|
||||
ScaleType s7)
|
||||
{
|
||||
set_scales(s0, s1, s2, s3, s4, s5, s6, s7);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void set_scales_from_float(
|
||||
float s0, float s1, float s2, float s3, float s4, float s5, float s6, float s7)
|
||||
{
|
||||
set_scales(ScaleType(s0),
|
||||
ScaleType(s1),
|
||||
ScaleType(s2),
|
||||
ScaleType(s3),
|
||||
ScaleType(s4),
|
||||
ScaleType(s5),
|
||||
ScaleType(s6),
|
||||
ScaleType(s7));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void set_scales(ScaleType s0,
|
||||
ScaleType s1,
|
||||
ScaleType s2,
|
||||
ScaleType s3,
|
||||
ScaleType s4,
|
||||
ScaleType s5,
|
||||
ScaleType s6,
|
||||
ScaleType s7)
|
||||
{
|
||||
data_ = 0;
|
||||
pack_scale(s0, 7);
|
||||
pack_scale(s1, 6);
|
||||
pack_scale(s2, 5);
|
||||
pack_scale(s3, 4);
|
||||
pack_scale(s4, 3);
|
||||
pack_scale(s5, 2);
|
||||
pack_scale(s6, 1);
|
||||
pack_scale(s7, 0);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr operator raw_type() const { return data_; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type& data() [[clang::lifetimebound]] { return data_; }
|
||||
CK_TILE_HOST_DEVICE constexpr raw_type data() const { return data_; }
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr float unpack_to_float(int i) const
|
||||
{
|
||||
return static_cast<float>(unpack_scale(i));
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr ScaleType unpack_scale(int i) const
|
||||
{
|
||||
return ScaleType(scales_[i]);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void pack_from_float(float scale, int i)
|
||||
{
|
||||
pack_scale(ScaleType(scale), i);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr void pack_scale(ScaleType scale, int i)
|
||||
{
|
||||
scales_[i] = scale.get();
|
||||
}
|
||||
};
|
||||
|
||||
// Type alias for e8m0_t scales
|
||||
using Packed4Scale_E8M0 = Packed4Scale<e8m0_t>;
|
||||
using Packed8Scale_E8M0 = Packed8Scale<e8m0_t>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -234,12 +234,12 @@ struct WarpGemmAttributeWmma
|
||||
}
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <typename... Params>
|
||||
template <typename... Params, typename AScaleType, typename BScaleType>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const AScaleType& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
const BScaleType& b_scale) const
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
@@ -253,11 +253,11 @@ struct WarpGemmAttributeWmma
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
template <typename... Params>
|
||||
template <typename... Params, typename AScaleType, typename BScaleType>
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const AScaleType& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
const BScaleType& b_scale) const
|
||||
{
|
||||
if constexpr(kTransC)
|
||||
{
|
||||
|
||||
@@ -19,6 +19,11 @@ template <typename Arch,
|
||||
typename MXTypeEnable = void>
|
||||
struct WmmaTraits;
|
||||
|
||||
// Tag used to select scale16 WMMA traits specializations.
|
||||
struct WmmaScale16Tag
|
||||
{
|
||||
};
|
||||
|
||||
// Generic WMMA implementation using traits
|
||||
template <typename Traits>
|
||||
struct WarpGemmAttributeWmmaImpl
|
||||
@@ -88,22 +93,22 @@ struct WarpGemmAttributeWmmaImpl
|
||||
Traits::template wmma_intrinsic<Params...>(a_vec, b_vec, CVecType{0.f}));
|
||||
}
|
||||
|
||||
template <typename... Params>
|
||||
template <typename... Params, typename AScaleType, typename BScaleType>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const AScaleType& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
const BScaleType& b_scale) const
|
||||
{
|
||||
c_vec = Traits::template wmma_intrinsic<Params...>(a_vec, a_scale, b_vec, b_scale, c_vec);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
template <typename... Params>
|
||||
template <typename... Params, typename AScaleType, typename BScaleType>
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const AScaleType& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
const BScaleType& b_scale) const
|
||||
{
|
||||
return bit_cast<CVecType>(Traits::template wmma_intrinsic<Params...>(
|
||||
a_vec, a_scale, b_vec, b_scale, CVecType{0.f}));
|
||||
@@ -177,6 +182,9 @@ using WarpGemmAttributeWmmaImpl_f32_32x16x128_f4 =
|
||||
using WarpGemmAttributeWmmaImpl_f32_32x32x128_f4 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f32_32x32x128_f4_scale16 = WarpGemmAttributeWmmaImpl<
|
||||
WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128, WmmaScale16Tag>>;
|
||||
|
||||
using WarpGemmAttributeWmmaImpl_f16_16x16x64_f8_f8 =
|
||||
WarpGemmAttributeWmmaImpl<WmmaTraits<gfx125_t, fp8_t, fp8_t, fp16_t, 16, 16, 64>>;
|
||||
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#include "warp_gemm_attribute_wmma_impl_base_traits.hpp"
|
||||
#include "warp_gemm_params.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
struct WmmaScale16Tag;
|
||||
|
||||
// int8 specialization - GFX11
|
||||
template <>
|
||||
struct WmmaTraits<gfx11_t, int8_t, int8_t, int32_t, 16, 16, 16>
|
||||
@@ -528,17 +531,18 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 16, 128>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
|
||||
template <bool IsScale16>
|
||||
struct WmmaTraitsGfx125PkFp4F32_32x32x128
|
||||
: WmmaTraitsBase<gfx12_t, pk_fp4_t, pk_fp4_t, float, 128, false, 32, 32>
|
||||
{
|
||||
using ArchType = gfx125_t;
|
||||
using ArchType = gfx125_t;
|
||||
using ScaleType = std::conditional_t<IsScale16, int64_t, int32_t>;
|
||||
|
||||
template <typename... Params>
|
||||
CK_TILE_DEVICE static CVecType wmma_intrinsic(const AVecType& a_vec,
|
||||
const int32_t& a_scale,
|
||||
const ScaleType& a_scale,
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale,
|
||||
const ScaleType& b_scale,
|
||||
const CVecType& c_vec)
|
||||
{
|
||||
#ifdef __gfx125__
|
||||
@@ -569,19 +573,38 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
|
||||
const auto& b_slice = b_buffer.template get_as<BSliceType>()[n];
|
||||
auto& c_slice = c_result.template get_as<CSliceType>()[n];
|
||||
|
||||
c_slice = __builtin_amdgcn_wmma_scale_f32_32x16x128_f4(
|
||||
bit_cast<int32x16_t>(a_slice),
|
||||
bit_cast<int32x8_t>(b_slice),
|
||||
0,
|
||||
c_slice,
|
||||
1, // OPSEL[0] - fixed to 1 for F4
|
||||
P::scale_a, // OPSEL_HI[0] - scale data type for A
|
||||
a_scale,
|
||||
n.value, // OPSEL[1] - select B scale (iterates over N blocks)
|
||||
P::scale_b, // OPSEL_HI[1] - scale data type for B
|
||||
b_scale,
|
||||
0, // NEG
|
||||
0); // NEG_HI
|
||||
if constexpr(IsScale16)
|
||||
{
|
||||
c_slice = __builtin_amdgcn_wmma_scale16_f32_32x16x128_f4(
|
||||
bit_cast<int32x16_t>(a_slice),
|
||||
bit_cast<int32x8_t>(b_slice),
|
||||
0,
|
||||
c_slice,
|
||||
1, // OPSEL[0] - fixed to 1 for F4
|
||||
P::scale_a, // OPSEL_HI[0] - scale data type for A
|
||||
a_scale,
|
||||
n.value, // OPSEL[1] - select B scale (iterates over N blocks)
|
||||
P::scale_b, // OPSEL_HI[1] - scale data type for B
|
||||
b_scale,
|
||||
0, // NEG
|
||||
0); // NEG_HI
|
||||
}
|
||||
else
|
||||
{
|
||||
c_slice = __builtin_amdgcn_wmma_scale_f32_32x16x128_f4(
|
||||
bit_cast<int32x16_t>(a_slice),
|
||||
bit_cast<int32x8_t>(b_slice),
|
||||
0,
|
||||
c_slice,
|
||||
1, // OPSEL[0] - fixed to 1 for F4
|
||||
P::scale_a, // OPSEL_HI[0] - scale data type for A
|
||||
a_scale,
|
||||
n.value, // OPSEL[1] - select B scale (iterates over N blocks)
|
||||
P::scale_b, // OPSEL_HI[1] - scale data type for B
|
||||
b_scale,
|
||||
0, // NEG
|
||||
0); // NEG_HI
|
||||
}
|
||||
});
|
||||
|
||||
return bit_cast<CVecType>(c_result);
|
||||
@@ -602,7 +625,8 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
|
||||
#ifdef __gfx125__
|
||||
// Pass default scale values 1.0f
|
||||
Packed4Scale_E8M0 pkscale(1.0f, 1.0f, 1.0f, 1.0f);
|
||||
return wmma_intrinsic(a_vec, pkscale, b_vec, pkscale, c_vec);
|
||||
const auto default_scale = static_cast<ScaleType>(pkscale);
|
||||
return wmma_intrinsic(a_vec, default_scale, b_vec, default_scale, c_vec);
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
@@ -612,6 +636,18 @@ struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128>
|
||||
: WmmaTraitsGfx125PkFp4F32_32x32x128<false>
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct WmmaTraits<gfx125_t, pk_fp4_t, pk_fp4_t, float, 32, 32, 128, WmmaScale16Tag>
|
||||
: WmmaTraitsGfx125PkFp4F32_32x32x128<true>
|
||||
{
|
||||
};
|
||||
|
||||
// f8f6f4 specialization - GFX125
|
||||
enum F8F6F4OpDataTypeEnum
|
||||
{
|
||||
|
||||
@@ -33,6 +33,7 @@ template <typename AType,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccessA = ESingle,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA,
|
||||
bool IsScale16 = false,
|
||||
typename Enable = void>
|
||||
struct Dispatcher;
|
||||
|
||||
@@ -178,10 +179,10 @@ template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Ty
|
||||
|
||||
#if !defined(__gfx125__)
|
||||
// scale mfma based f8f6f4
|
||||
template<typename A, typename B, WGAttrNumAccessEnum I>
|
||||
struct Dispatcher<A, B, float, 16, 16, 128, false, false, false, I, I, std::enable_if_t<I != EDefault>> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4<A, B, I>; };
|
||||
template<typename A, typename B, WGAttrNumAccessEnum I>
|
||||
struct Dispatcher<A, B, float, 16, 16, 128, true, false, false, I, I, std::enable_if_t<I != EDefault>> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4_CTransposed<A, B, I>; };
|
||||
template<typename A, typename B, WGAttrNumAccessEnum I, bool IsScale16>
|
||||
struct Dispatcher<A, B, float, 16, 16, 128, false, false, false, I, I, IsScale16, std::enable_if_t<I != EDefault>> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4<A, B, I>; };
|
||||
template<typename A, typename B, WGAttrNumAccessEnum I, bool IsScale16>
|
||||
struct Dispatcher<A, B, float, 16, 16, 128, true, false, false, I, I, IsScale16, std::enable_if_t<I != EDefault>> { using Type = WarpGemmMfma_f32_16x16x128_f8f6f4_CTransposed<A, B, I>; };
|
||||
#endif
|
||||
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
@@ -224,7 +225,7 @@ template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<f
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 64, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type = WarpGemmWmma_f32_16x16x64_bf8_f8<TransposeC, AttrNumAccess>; };
|
||||
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 16, 128, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type = WarpGemmWmma_f32_32x16x128_f4<TransposeC, AttrNumAccess>; };
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 32, 128, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f4<TransposeC, AttrNumAccess>; };
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess, bool IsScale16> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 32, 32, 128, TransposeC, false, false, AttrNumAccess, AttrNumAccess, IsScale16> : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f4<TransposeC, AttrNumAccess, IsScale16>; };
|
||||
|
||||
#if defined(__gfx125__)
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type = WarpGemmWmma_f32_16x16x64_f8_f8<TransposeC, AttrNumAccess>; };
|
||||
@@ -244,8 +245,27 @@ template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, true> { using Typ
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, true> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8_CTransposed; };
|
||||
#endif
|
||||
|
||||
template<typename A, typename B, bool TransposeC, WGAttrNumAccessEnum AttrNumAccessA, WGAttrNumAccessEnum AttrNumAccessB>
|
||||
struct Dispatcher<A, B, float, 32, 32, 128, TransposeC, false, false, AttrNumAccessA, AttrNumAccessB> : WmmaTag { using Type = WarpGemmWmma_f32_32x32x128_f8f6f4<A, B, TransposeC, AttrNumAccessA, AttrNumAccessB>; };
|
||||
template<typename A,
|
||||
typename B,
|
||||
bool TransposeC,
|
||||
WGAttrNumAccessEnum AttrNumAccessA,
|
||||
WGAttrNumAccessEnum AttrNumAccessB,
|
||||
bool IsScale16>
|
||||
struct Dispatcher<A,
|
||||
B,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
128,
|
||||
TransposeC,
|
||||
false,
|
||||
false,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB,
|
||||
IsScale16> : WmmaTag
|
||||
{
|
||||
using Type = WarpGemmWmma_f32_32x32x128_f8f6f4<A, B, TransposeC, AttrNumAccessA, AttrNumAccessB>;
|
||||
};
|
||||
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<fp8_t, fp8_t, half_t, 16, 16, 64, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type =WarpGemmWmma_f16_16x16x64_f8_f8<TransposeC, AttrNumAccess>; };
|
||||
template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<bf8_t, bf8_t, half_t, 16, 16, 64, TransposeC, false, false, AttrNumAccess, AttrNumAccess> : WmmaTag { using Type =WarpGemmWmma_f16_16x16x64_bf8_bf8<TransposeC, AttrNumAccess>; };
|
||||
@@ -265,12 +285,12 @@ template<bool TransposeC, WGAttrNumAccessEnum AttrNumAccess> struct Dispatcher<u
|
||||
|
||||
template <typename AType, typename BType, typename AccType,
|
||||
index_t M, index_t N, index_t K,
|
||||
bool TransposeC, bool SA, bool SS>
|
||||
bool TransposeC, bool SA, bool SS, bool IsScale16>
|
||||
struct Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS,
|
||||
EDefault, EDefault,
|
||||
EDefault, EDefault, IsScale16,
|
||||
std::enable_if_t<!std::is_base_of_v<WmmaTag,
|
||||
Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS, ESingle, ESingle, void>>>>
|
||||
: Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS, ESingle, ESingle, void> {};
|
||||
Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS, ESingle, ESingle, IsScale16, void>>>>
|
||||
: Dispatcher<AType, BType, AccType, M, N, K, TransposeC, SA, SS, ESingle, ESingle, IsScale16, void> {};
|
||||
|
||||
// clang-format on
|
||||
} // namespace warp_gemm_dispatcher
|
||||
@@ -286,7 +306,8 @@ template <typename AType,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccessA = WGAttrNumAccessEnum::Default,
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA>
|
||||
WGAttrNumAccessEnum AttrNumAccessB = AttrNumAccessA,
|
||||
bool IsScale16 = false>
|
||||
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
|
||||
AType,
|
||||
BType,
|
||||
@@ -298,6 +319,7 @@ using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccessA,
|
||||
AttrNumAccessB>::Type;
|
||||
AttrNumAccessB,
|
||||
IsScale16>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -90,12 +90,17 @@ struct WarpGemmImpl
|
||||
c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
|
||||
}
|
||||
|
||||
template <typename... Params, typename CTensor, typename ATensor, typename BTensor>
|
||||
template <typename... Params,
|
||||
typename CTensor,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
typename AScaleType,
|
||||
typename BScaleType>
|
||||
CK_TILE_DEVICE void operator()(CTensor& c,
|
||||
const ATensor& a,
|
||||
const BTensor& b,
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale) const
|
||||
const AScaleType& a_scale,
|
||||
const BScaleType& b_scale) const
|
||||
{
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
|
||||
detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
@@ -141,11 +146,15 @@ struct WarpGemmImpl
|
||||
return c;
|
||||
}
|
||||
|
||||
template <typename... Params, typename ATensor, typename BTensor>
|
||||
template <typename... Params,
|
||||
typename ATensor,
|
||||
typename BTensor,
|
||||
typename AScaleType,
|
||||
typename BScaleType>
|
||||
CK_TILE_DEVICE auto operator()(const ATensor& a,
|
||||
const BTensor& b,
|
||||
const int32_t& a_scale,
|
||||
const int32_t& b_scale) const
|
||||
const AScaleType& a_scale,
|
||||
const BScaleType& b_scale) const
|
||||
{
|
||||
using CTensor = CWarpTensor;
|
||||
static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
|
||||
|
||||
@@ -187,12 +187,16 @@ using WarpGemmWmma_f32_32x16x128_f4 =
|
||||
AttrNumAccess,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <bool kTransC = false, WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Default>
|
||||
using WarpGemmWmma_f32_32x32x128_f4 =
|
||||
WarpGemmImpl<WarpGemmAttributeWmma<WarpGemmAttributeWmmaImpl_f32_32x32x128_f4,
|
||||
kTransC,
|
||||
AttrNumAccess,
|
||||
AttrNumAccess>>;
|
||||
template <bool kTransC = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Default,
|
||||
bool IsScale16 = false>
|
||||
using WarpGemmWmma_f32_32x32x128_f4 = WarpGemmImpl<
|
||||
WarpGemmAttributeWmma<std::conditional_t<IsScale16,
|
||||
WarpGemmAttributeWmmaImpl_f32_32x32x128_f4_scale16,
|
||||
WarpGemmAttributeWmmaImpl_f32_32x32x128_f4>,
|
||||
kTransC,
|
||||
AttrNumAccess,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
|
||||
@@ -4,3 +4,7 @@
|
||||
if(GPU_TARGETS MATCHES "gfx95")
|
||||
add_gtest_executable(test_ck_tile_wg_16x16x128_fp4 test_f32_16x16x128_fp4.cpp)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx125")
|
||||
add_gtest_executable(test_ck_tile_wg_32x16x128_fp4 test_f32_32x16x128_fp4.cpp)
|
||||
endif()
|
||||
|
||||
@@ -1,153 +1,19 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
template <typename A,
|
||||
typename B,
|
||||
typename Acc,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum NA = WGAttrNumAccessEnum::Single>
|
||||
struct WGDispCase
|
||||
{
|
||||
using AType = A;
|
||||
using BType = B;
|
||||
using AccType = Acc;
|
||||
static constexpr index_t MPerWave = M;
|
||||
static constexpr index_t NPerWave = N;
|
||||
static constexpr index_t KPerWave = K;
|
||||
static constexpr bool kTransposeC = TransposeC;
|
||||
static constexpr bool kSwizzleA = SwizzleA;
|
||||
static constexpr bool kUSS = UseStructuredSparsity;
|
||||
static constexpr WGAttrNumAccessEnum kNA = NA;
|
||||
};
|
||||
#include "test_gemm_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using WGDispatcherTypesList =
|
||||
::testing::Types<WGDispCase<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, float, 16, 16, 128, false>>;
|
||||
::testing::Types<ck_tile::test::warp_gemm::WGDispCase<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
float,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::WGAttrNumAccessEnum::Single>>;
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename CType,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
bool TransposeC,
|
||||
bool SwizzleA,
|
||||
bool UseStructuredSparsity,
|
||||
WGAttrNumAccessEnum NumAccess>
|
||||
struct WarpGemmKernel
|
||||
{
|
||||
static constexpr int kBlockSize = 64;
|
||||
__device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const
|
||||
{
|
||||
using WarpGemm = ck_tile::WarpGemmDispatcher<AType,
|
||||
BType,
|
||||
CType,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
NumAccess>;
|
||||
// A: [M,K] row-major (packed)
|
||||
const auto a_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
|
||||
static_cast<AType*>(A),
|
||||
ck_tile::make_tuple(M, K),
|
||||
ck_tile::make_tuple(K, ck_tile::number<1>{}),
|
||||
ck_tile::number<K>{},
|
||||
ck_tile::number<1>{});
|
||||
// B: expose as logical [N,K] with strides (1, N) over the original row-major [K,N] buffer
|
||||
const auto b_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
|
||||
static_cast<BType*>(B),
|
||||
ck_tile::make_tuple(N, K),
|
||||
ck_tile::make_tuple(K, ck_tile::number<1>{}),
|
||||
ck_tile::number<K>{},
|
||||
ck_tile::number<1>{});
|
||||
// C: [M,N] row-major (packed)
|
||||
const auto c_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
|
||||
static_cast<CType*>(C),
|
||||
ck_tile::make_tuple(M, N),
|
||||
ck_tile::make_tuple(N, ck_tile::number<1>{}),
|
||||
ck_tile::number<N>{},
|
||||
ck_tile::number<1>{});
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths();
|
||||
constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths();
|
||||
constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths();
|
||||
|
||||
auto a_win = ck_tile::make_tile_window(
|
||||
a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution());
|
||||
auto b_win = ck_tile::make_tile_window(
|
||||
b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution());
|
||||
auto c_win = ck_tile::make_tile_window(
|
||||
c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution());
|
||||
|
||||
AWarpTensor a_tile;
|
||||
BWarpTensor b_tile;
|
||||
ck_tile::load_tile(a_tile, a_win);
|
||||
ck_tile::load_tile(b_tile, b_win);
|
||||
|
||||
auto scale_a = static_cast<int32_t>(static_cast<ck_tile::e8m0_t*>(ScaleA)[0].get());
|
||||
auto scale_b = static_cast<int32_t>(static_cast<ck_tile::e8m0_t*>(ScaleB)[0].get());
|
||||
|
||||
auto c_tile =
|
||||
WarpGemm{}.template operator()<OpSelA<0>, OpSelB<0>>(a_tile, b_tile, scale_a, scale_b);
|
||||
|
||||
ck_tile::store_tile(c_win, c_tile);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Case>
|
||||
static void RunWarpGemmCase(const ck_tile::HostTensor<typename Case::AType>& A,
|
||||
const ck_tile::HostTensor<typename Case::BType>& B,
|
||||
const ck_tile::HostTensor<e8m0_t>& ScaleA,
|
||||
const ck_tile::HostTensor<e8m0_t>& ScaleB,
|
||||
ck_tile::HostTensor<typename Case::AccType>& C)
|
||||
{
|
||||
ck_tile::DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB);
|
||||
dim3 grid(1), block{64};
|
||||
|
||||
using Kernel = WarpGemmKernel<typename Case::AType,
|
||||
typename Case::BType,
|
||||
typename Case::AccType,
|
||||
Case::MPerWave,
|
||||
Case::NPerWave,
|
||||
Case::KPerWave,
|
||||
Case::kTransposeC,
|
||||
Case::kSwizzleA,
|
||||
Case::kUSS,
|
||||
Case::kNA>;
|
||||
|
||||
(void)ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel(Kernel{},
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
Ad.GetDeviceBuffer(),
|
||||
Bd.GetDeviceBuffer(),
|
||||
Cd.GetDeviceBuffer(),
|
||||
SAd.GetDeviceBuffer(),
|
||||
SBd.GetDeviceBuffer()));
|
||||
|
||||
Cd.FromDevice(C.mData.data());
|
||||
}
|
||||
|
||||
template <typename Case>
|
||||
template <typename T>
|
||||
class WGRuntimeTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
@@ -156,38 +22,6 @@ TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG)
|
||||
{
|
||||
using Case = TypeParam;
|
||||
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
using ck_tile::e8m0_t;
|
||||
|
||||
constexpr index_t M = Case::MPerWave;
|
||||
constexpr index_t N = Case::NPerWave;
|
||||
constexpr index_t K = Case::KPerWave;
|
||||
|
||||
auto ScaleA = e8m0_t{2.f};
|
||||
auto ScaleB = e8m0_t{4.f};
|
||||
|
||||
ck_tile::HostTensor<AType> A({M, K});
|
||||
ck_tile::HostTensor<BType> B({N, K});
|
||||
ck_tile::HostTensor<CType> C({M, N});
|
||||
ck_tile::HostTensor<e8m0_t> sA({M, 1});
|
||||
ck_tile::HostTensor<e8m0_t> sB({N, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<AType>{-5.f, 5.f}(A);
|
||||
ck_tile::FillUniformDistribution<BType>{-5.f, 5.f}(B);
|
||||
C.SetZero();
|
||||
ck_tile::FillConstant<e8m0_t>{ScaleA}(sA);
|
||||
ck_tile::FillConstant<e8m0_t>{ScaleB}(sB);
|
||||
|
||||
RunWarpGemmCase<Case>(A, B, sA, sB, C);
|
||||
|
||||
ck_tile::HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
ck_tile::reference_mx_gemm<AType, BType, e8m0_t, e8m0_t, CType, CType>(
|
||||
A, B.transpose(), C_ref, sA, sB.transpose());
|
||||
|
||||
EXPECT_TRUE(ck_tile::check_err(C, C_ref, "Warp gemm result error."));
|
||||
ck_tile::test::warp_gemm::
|
||||
RunCompareDispatcherAndReference<TypeParam, 16, 16, 128, true, false>();
|
||||
}
|
||||
|
||||
39
test/ck_tile/warp_gemm/test_f32_32x16x128_fp4.cpp
Normal file
39
test/ck_tile/warp_gemm/test_f32_32x16x128_fp4.cpp
Normal file
@@ -0,0 +1,39 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "test_gemm_util.hpp"
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
using WGDispatcherTypesList =
|
||||
::testing::Types<ck_tile::test::warp_gemm::WGDispCase<ck_tile::pk_fp4_t,
|
||||
ck_tile::pk_fp4_t,
|
||||
float,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::WGAttrNumAccessEnum::Single>>;
|
||||
|
||||
template <typename T>
|
||||
class WGRuntimeTest : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(WGRuntimeTest, WGDispatcherTypesList);
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_NonScaled)
|
||||
{
|
||||
ck_tile::test::warp_gemm::
|
||||
RunCompareDispatcherAndReference<TypeParam, 32, 16, 128, false, false>();
|
||||
}
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_Scale16)
|
||||
{
|
||||
ck_tile::test::warp_gemm::
|
||||
RunCompareDispatcherAndReference<TypeParam, 32, 32, 128, true, true>();
|
||||
}
|
||||
|
||||
TYPED_TEST(WGRuntimeTest, Compare_Dispatcher_MakeWG_Scale32)
|
||||
{
|
||||
ck_tile::test::warp_gemm::
|
||||
RunCompareDispatcherAndReference<TypeParam, 32, 32, 128, true, false>();
|
||||
}
|
||||
223
test/ck_tile/warp_gemm/test_gemm_util.hpp
Normal file
223
test/ck_tile/warp_gemm/test_gemm_util.hpp
Normal file
@@ -0,0 +1,223 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/core/numeric/mxfp_scale.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile::test::warp_gemm {
|
||||
|
||||
template <typename A,
|
||||
typename B,
|
||||
typename Acc,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum NA = WGAttrNumAccessEnum::Single>
|
||||
struct WGDispCase
|
||||
{
|
||||
using AType = A;
|
||||
using BType = B;
|
||||
using AccType = Acc;
|
||||
static constexpr bool kTransposeC = TransposeC;
|
||||
static constexpr bool kSwizzleA = SwizzleA;
|
||||
static constexpr bool kUSS = UseStructuredSparsity;
|
||||
static constexpr WGAttrNumAccessEnum kNA = NA;
|
||||
};
|
||||
|
||||
template <typename Case,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool UseScale = true,
|
||||
bool IsScale16 = false>
|
||||
struct WarpGemmKernel
|
||||
{
|
||||
static constexpr int kBlockSize = 64;
|
||||
__device__ void operator()(void* A, void* B, void* C, void* ScaleA, void* ScaleB) const
|
||||
{
|
||||
using WarpGemm = ck_tile::WarpGemmDispatcher<typename Case::AType,
|
||||
typename Case::BType,
|
||||
typename Case::AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
Case::kTransposeC,
|
||||
Case::kSwizzleA,
|
||||
Case::kUSS,
|
||||
Case::kNA,
|
||||
Case::kNA,
|
||||
IsScale16>;
|
||||
|
||||
const auto a_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
|
||||
static_cast<typename Case::AType*>(A),
|
||||
ck_tile::make_tuple(MPerWave, KPerWave),
|
||||
ck_tile::make_tuple(KPerWave, ck_tile::number<1>{}),
|
||||
ck_tile::number<KPerWave>{},
|
||||
ck_tile::number<1>{});
|
||||
const auto b_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
|
||||
static_cast<typename Case::BType*>(B),
|
||||
ck_tile::make_tuple(NPerWave, KPerWave),
|
||||
ck_tile::make_tuple(KPerWave, ck_tile::number<1>{}),
|
||||
ck_tile::number<KPerWave>{},
|
||||
ck_tile::number<1>{});
|
||||
const auto c_view = ck_tile::make_naive_tensor_view<ck_tile::address_space_enum::global>(
|
||||
static_cast<typename Case::AccType*>(C),
|
||||
ck_tile::make_tuple(MPerWave, NPerWave),
|
||||
ck_tile::make_tuple(NPerWave, ck_tile::number<1>{}),
|
||||
ck_tile::number<NPerWave>{},
|
||||
ck_tile::number<1>{});
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto a_len = AWarpTensor::get_tile_distribution().get_lengths();
|
||||
constexpr auto b_len = BWarpTensor::get_tile_distribution().get_lengths();
|
||||
constexpr auto c_len = CWarpTensor::get_tile_distribution().get_lengths();
|
||||
|
||||
auto a_win = ck_tile::make_tile_window(
|
||||
a_view, a_len, ck_tile::make_multi_index(0, 0), AWarpTensor::get_tile_distribution());
|
||||
auto b_win = ck_tile::make_tile_window(
|
||||
b_view, b_len, ck_tile::make_multi_index(0, 0), BWarpTensor::get_tile_distribution());
|
||||
auto c_win = ck_tile::make_tile_window(
|
||||
c_view, c_len, ck_tile::make_multi_index(0, 0), CWarpTensor::get_tile_distribution());
|
||||
|
||||
AWarpTensor a_tile;
|
||||
BWarpTensor b_tile;
|
||||
ck_tile::load_tile(a_tile, a_win);
|
||||
ck_tile::load_tile(b_tile, b_win);
|
||||
|
||||
const auto c_tile = [&]() {
|
||||
if constexpr(UseScale)
|
||||
{
|
||||
using ScaleType = std::conditional_t<IsScale16, int64_t, int32_t>;
|
||||
const auto scale_a = static_cast<ck_tile::e8m0_t*>(ScaleA)[0];
|
||||
const auto scale_b = static_cast<ck_tile::e8m0_t*>(ScaleB)[0];
|
||||
const auto packed_scale_a = [&]() -> ScaleType {
|
||||
if constexpr(IsScale16)
|
||||
{
|
||||
Packed8Scale_E8M0 pkscale(
|
||||
scale_a, scale_a, scale_a, scale_a, scale_a, scale_a, scale_a, scale_a);
|
||||
return static_cast<ScaleType>(pkscale);
|
||||
}
|
||||
else
|
||||
{
|
||||
Packed4Scale_E8M0 pkscale(scale_a, scale_a, scale_a, scale_a);
|
||||
return static_cast<ScaleType>(pkscale);
|
||||
}
|
||||
}();
|
||||
const auto packed_scale_b = [&]() -> ScaleType {
|
||||
if constexpr(IsScale16)
|
||||
{
|
||||
Packed8Scale_E8M0 pkscale(
|
||||
scale_b, scale_b, scale_b, scale_b, scale_b, scale_b, scale_b, scale_b);
|
||||
return static_cast<ScaleType>(pkscale);
|
||||
}
|
||||
else
|
||||
{
|
||||
Packed4Scale_E8M0 pkscale(scale_b, scale_b, scale_b, scale_b);
|
||||
return static_cast<ScaleType>(pkscale);
|
||||
}
|
||||
}();
|
||||
return WarpGemm{}.template operator()<OpSelA<0>, OpSelB<0>>(
|
||||
a_tile, b_tile, packed_scale_a, packed_scale_b);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::ignore = ScaleA;
|
||||
ck_tile::ignore = ScaleB;
|
||||
return WarpGemm{}.template operator()<OpSelA<0>, OpSelB<0>>(a_tile, b_tile);
|
||||
}
|
||||
}();
|
||||
|
||||
ck_tile::store_tile(c_win, c_tile);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Case,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool UseScale = true,
|
||||
bool IsScale16 = false>
|
||||
void RunWarpGemmCase(const ck_tile::HostTensor<typename Case::AType>& A,
|
||||
const ck_tile::HostTensor<typename Case::BType>& B,
|
||||
const ck_tile::HostTensor<ck_tile::e8m0_t>& ScaleA,
|
||||
const ck_tile::HostTensor<ck_tile::e8m0_t>& ScaleB,
|
||||
ck_tile::HostTensor<typename Case::AccType>& C)
|
||||
{
|
||||
ck_tile::DeviceMem Ad(A), Bd(B), Cd(C), SAd(ScaleA), SBd(ScaleB);
|
||||
dim3 grid(1), block{64};
|
||||
|
||||
(void)ck_tile::launch_kernel(
|
||||
ck_tile::stream_config{nullptr, true, 0, 0, 1},
|
||||
ck_tile::make_kernel(
|
||||
WarpGemmKernel<Case, MPerWave, NPerWave, KPerWave, UseScale, IsScale16>{},
|
||||
grid,
|
||||
block,
|
||||
0,
|
||||
Ad.GetDeviceBuffer(),
|
||||
Bd.GetDeviceBuffer(),
|
||||
Cd.GetDeviceBuffer(),
|
||||
SAd.GetDeviceBuffer(),
|
||||
SBd.GetDeviceBuffer()));
|
||||
|
||||
Cd.FromDevice(C.mData.data());
|
||||
}
|
||||
|
||||
template <typename Case,
|
||||
index_t MPerWave,
|
||||
index_t NPerWave,
|
||||
index_t KPerWave,
|
||||
bool UseScale = true,
|
||||
bool IsScale16 = false>
|
||||
void RunCompareDispatcherAndReference()
|
||||
{
|
||||
using AType = typename Case::AType;
|
||||
using BType = typename Case::BType;
|
||||
using CType = typename Case::AccType;
|
||||
|
||||
constexpr index_t M = MPerWave;
|
||||
constexpr index_t N = NPerWave;
|
||||
constexpr index_t K = KPerWave;
|
||||
|
||||
const auto ScaleA = ck_tile::e8m0_t{2.f};
|
||||
const auto ScaleB = ck_tile::e8m0_t{4.f};
|
||||
|
||||
ck_tile::HostTensor<AType> A({M, K});
|
||||
ck_tile::HostTensor<BType> B({N, K});
|
||||
ck_tile::HostTensor<CType> C({M, N});
|
||||
ck_tile::HostTensor<ck_tile::e8m0_t> sA({M, 1});
|
||||
ck_tile::HostTensor<ck_tile::e8m0_t> sB({N, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<AType>{-5.f, 5.f}(A);
|
||||
ck_tile::FillUniformDistribution<BType>{-5.f, 5.f}(B);
|
||||
C.SetZero();
|
||||
ck_tile::FillConstant<ck_tile::e8m0_t>{ScaleA}(sA);
|
||||
ck_tile::FillConstant<ck_tile::e8m0_t>{ScaleB}(sB);
|
||||
|
||||
RunWarpGemmCase<Case, MPerWave, NPerWave, KPerWave, UseScale, IsScale16>(A, B, sA, sB, C);
|
||||
|
||||
ck_tile::HostTensor<CType> C_ref({M, N});
|
||||
C_ref.SetZero();
|
||||
|
||||
if constexpr(UseScale)
|
||||
{
|
||||
ck_tile::reference_mx_gemm<AType, BType, ck_tile::e8m0_t, ck_tile::e8m0_t, CType, CType>(
|
||||
A, B.transpose(), C_ref, sA, sB.transpose());
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_gemm<AType, BType, CType, CType>(A, B.transpose(), C_ref);
|
||||
}
|
||||
|
||||
EXPECT_TRUE(ck_tile::check_err(C, C_ref, "Warp gemm result error."));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::test::warp_gemm
|
||||
Reference in New Issue
Block a user