mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_TILE] Improve F8F6F4 Scaled WarpGemm (#3197)
* [CK_TILE] Improve F8F6F4 Scaled WarpGemm * Thanks, Copilot
This commit is contained in:
@@ -308,50 +308,50 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl<WarpGemmAttributeMfmaIter
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp4 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<pk_fp4_t, pk_fp4_t>,
|
||||
AttrNumAccess>>;
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< //
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl< //
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< //
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl<
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< //
|
||||
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, fp8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<fp8_t, bf8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, fp8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed =
|
||||
WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution<
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8<WGAttrCtlEnum::Default_>,
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4<bf8_t, bf8_t>,
|
||||
AttrNumAccess>>;
|
||||
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
|
||||
@@ -1527,113 +1527,15 @@ using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base
|
||||
struct WarpGemmAttributeMfmaImpl_f32_16x16x128_f8f6f4
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = AType_;
|
||||
using BDataType = BType_;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 32>;
|
||||
using BVecType = ext_vector_t<BDataType, 32>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
static constexpr index_t kN = 16;
|
||||
static constexpr index_t kK = 128;
|
||||
|
||||
static constexpr index_t kAMBlock = 1;
|
||||
static constexpr index_t kBNBlock = 1;
|
||||
|
||||
static constexpr index_t kAMLane = 16;
|
||||
static constexpr index_t kBNLane = 16;
|
||||
static constexpr index_t kABKLane = 4;
|
||||
static constexpr index_t kABKPerLane = 32;
|
||||
|
||||
static constexpr index_t kCMLane = 4;
|
||||
static constexpr index_t kCNLane = 16;
|
||||
static constexpr index_t kCM0PerLane = 1;
|
||||
static constexpr index_t kCM1PerLane = 4;
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
|
||||
// opsel, scale_b)
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
#endif
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
|
||||
else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<fp8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_fp8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<fp8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_fp8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, fp8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
using WarpGemmAttributeMfmaImpl_f32_16x16x128_bf8_bf8 =
|
||||
WarpGemmAttributeMfmaImpl_f32_16x16x128_f8_bf8_base<bf8_t, bf8_t, Ctrl_>;
|
||||
|
||||
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
|
||||
struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
|
||||
{
|
||||
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
|
||||
using ADataType = pk_fp4_t;
|
||||
using BDataType = pk_fp4_t;
|
||||
using CDataType = float;
|
||||
|
||||
using AVecType = ext_vector_t<ADataType, 16>;
|
||||
using BVecType = ext_vector_t<BDataType, 16>;
|
||||
using AVecType = ext_vector_t<ADataType, 32 / numeric_traits<ADataType>::PackedSize>;
|
||||
using BVecType = ext_vector_t<BDataType, 32 / numeric_traits<BDataType>::PackedSize>;
|
||||
using CVecType = ext_vector_t<CDataType, 4>;
|
||||
|
||||
static constexpr index_t kM = 16;
|
||||
@@ -1662,21 +1564,38 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
|
||||
const int32_t& b_scale,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
//__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
|
||||
// opsel, scale_b)
|
||||
#if defined(__gfx950__)
|
||||
auto arg_a = bit_cast<int32x4_t>(a_vec);
|
||||
auto arg_b = bit_cast<int32x4_t>(b_vec);
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
c_vec,
|
||||
4,
|
||||
4,
|
||||
opselA,
|
||||
a_scale,
|
||||
opselB,
|
||||
b_scale);
|
||||
auto dtype2conf = [](auto dtype) {
|
||||
if constexpr(std::is_same_v<decltype(dtype), fp8_t>)
|
||||
return make_tuple(number<0>{}, int32x8_t{});
|
||||
else if constexpr(std::is_same_v<decltype(dtype), bf8_t>)
|
||||
return make_tuple(number<1>{}, int32x8_t{});
|
||||
// else if e2m3 => make_tuple(number<2>{}, int32x6_t{})
|
||||
// else if e3m2 => make_tuple(number<3>{}, int32x6_t{})
|
||||
else if constexpr(std::is_same_v<decltype(dtype), pk_fp4_t>)
|
||||
return make_tuple(number<4>{}, int32x4_t{});
|
||||
else
|
||||
static_assert(false, "Unsupported data type for mfma scale");
|
||||
};
|
||||
auto dtype2code = [&](auto dtype) { return dtype2conf(dtype)(number<0>{}); };
|
||||
auto dtype2vec = [&](auto dtype) { return dtype2conf(dtype)(number<1>{}); };
|
||||
auto arg256 = [&](auto x) {
|
||||
if constexpr(sizeof(x) == 16)
|
||||
return int32x8_t{x[0], x[1], x[2], x[3], 0, 0, 0, 0};
|
||||
else if constexpr(sizeof(x) == 24)
|
||||
return int32x8_t{x[0], x[1], x[2], x[3], x[4], x[5], 0, 0};
|
||||
else if constexpr(sizeof(x) == 32)
|
||||
return x;
|
||||
else
|
||||
static_assert(false, "Unexpected vector size for mfma scale");
|
||||
};
|
||||
|
||||
auto arg_a = bit_cast<decltype(dtype2vec(ADataType{}))>(a_vec);
|
||||
auto arg_b = bit_cast<decltype(dtype2vec(BDataType{}))>(b_vec);
|
||||
constexpr int cbsz = decltype(dtype2code(ADataType{}))::value;
|
||||
constexpr int blgp = decltype(dtype2code(BDataType{}))::value;
|
||||
c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg256(arg_a), arg256(arg_b), c_vec, cbsz, blgp, opselA, a_scale, opselB, b_scale);
|
||||
#else
|
||||
ck_tile::ignore = c_vec;
|
||||
ck_tile::ignore = a_vec;
|
||||
@@ -1693,26 +1612,25 @@ struct WarpGemmAttributeMfmaScaleImpl_f32_16x16x128_fp4
|
||||
const BVecType& b_vec,
|
||||
const int32_t& b_scale) const
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
auto arg_a = bit_cast<int32x4_t>(a_vec);
|
||||
auto arg_b = bit_cast<int32x4_t>(b_vec);
|
||||
return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
int32x8_t{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
int32x8_t{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
CVecType{0.f},
|
||||
4,
|
||||
4,
|
||||
opselA,
|
||||
a_scale,
|
||||
opselB,
|
||||
b_scale));
|
||||
#else
|
||||
ck_tile::ignore = a_vec;
|
||||
ck_tile::ignore = b_vec;
|
||||
ck_tile::ignore = a_scale;
|
||||
ck_tile::ignore = b_scale;
|
||||
return CVecType{0.f};
|
||||
#endif
|
||||
CVecType c_vec{0.f};
|
||||
operator()<opselA, opselB>(c_vec, a_vec, a_scale, b_vec, b_scale);
|
||||
return c_vec;
|
||||
}
|
||||
|
||||
// c_vec += a_vec * b_vec
|
||||
template <bool post_nop_ = false>
|
||||
CK_TILE_DEVICE void operator()(CVecType& c_vec,
|
||||
const AVecType& a_vec,
|
||||
const BVecType& b_vec,
|
||||
bool_constant<post_nop_> = {}) const
|
||||
{
|
||||
operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0);
|
||||
}
|
||||
|
||||
// c_vec = a_vec * b_vec
|
||||
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
|
||||
{
|
||||
return operator()<0, 0>(a_vec, 0, b_vec, 0);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -10,6 +10,13 @@
|
||||
namespace ck_tile {
|
||||
|
||||
namespace impl {
|
||||
namespace warp_gemm_dispatcher {
|
||||
|
||||
// C++20 using enum
|
||||
static inline constexpr auto ESingle = WGAttrNumAccessEnum::Single;
|
||||
static inline constexpr auto EDouble = WGAttrNumAccessEnum::Double;
|
||||
static inline constexpr auto EQuad = WGAttrNumAccessEnum::Quad;
|
||||
|
||||
template <typename AType,
|
||||
typename BType,
|
||||
typename AccType,
|
||||
@@ -19,154 +26,132 @@ template <typename AType,
|
||||
bool TransposeC,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
struct WarpGemmDispatcher;
|
||||
WGAttrNumAccessEnum AttrNumAccess = ESingle>
|
||||
struct Dispatcher;
|
||||
|
||||
// clang-format off
|
||||
// fp32
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmDispatcher<float, float, float, 16, 16, 4, false> { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
|
||||
template<> struct WarpGemmDispatcher<float, float, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
|
||||
template<> struct WarpGemmDispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 4, false> { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
|
||||
template<> struct Dispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
|
||||
// fp16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M32N32K16<EDouble>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M16N16K32<EDouble>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, true, false, false, EDouble> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaF16F16F32M4N64K16; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaF16F16F32M64N4K16; };
|
||||
// WMMA cases
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f16_f16<TransposeC>;};
|
||||
template<bool TransposeC> struct Dispatcher<half_t, half_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f16_f16<TransposeC>;};
|
||||
#else
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
#endif
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 32, 32, 16, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M32N32K16; };
|
||||
template<> struct Dispatcher<half_t, half_t, float, 16, 16, 32, false, false, true> { using Type = WarpGemmSmfmacF16F16F32M16N16K32; };
|
||||
|
||||
// bf16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, false, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 32, true, false, false, EDouble> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<EDouble>; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 4, 64, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M4N64K16; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 64, 4, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M64N4K16; };
|
||||
// WMMA cases
|
||||
#if defined(__gfx11__) || defined(__gfx12__)
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf16_bf16<TransposeC>; };
|
||||
#else
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
|
||||
#endif
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct Dispatcher<bf16_t, bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, true> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, true> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, true> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, true> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_fp8_fp8; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_fp8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_fp8_bf8; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_fp8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 32, false> { using Type = WarpGemmMfma_f32_32x32x32_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, false> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 32, true> { using Type = WarpGemmMfma_f32_16x16x32_bf8_bf8_CTransposed; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 64, false> { using Type = WarpGemmMfma_f32_16x16x64_bf8_bf8; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; };
|
||||
// scale mfma based f8f6f4
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 128, false, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8_CTransposed<I>; };
|
||||
template<WGAttrNumAccessEnum I> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 128, true, false, false, I> { using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8_CTransposed<I>; };
|
||||
template<> struct Dispatcher<pk_fp4_t, pk_fp4_t, float, 16, 16, 128, false> { using Type = WarpGemmMfma_f32_16x16x128_fp4<>; };
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 64, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_fp8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_fp8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_bf8_fp8<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_bf8_bf8<WGAttrNumAccessEnum::Quad>; };
|
||||
|
||||
template<> struct WarpGemmDispatcher<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t, float, 16, 16, 128, false, false, false, WGAttrNumAccessEnum::Quad> {
|
||||
using Type = WarpGemmMfma_f32_16x16x128_fp4<WGAttrNumAccessEnum::Quad>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<>; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<>; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<>; };
|
||||
template<> struct Dispatcher<fp8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_fp8_fp8<EQuad>; };
|
||||
template<> struct Dispatcher<fp8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_fp8_bf8<EQuad>; };
|
||||
template<> struct Dispatcher<bf8_t, fp8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_fp8<EQuad>; };
|
||||
template<> struct Dispatcher<bf8_t, bf8_t, float, 32, 32, 64, false, false, false, EQuad> { using Type = WarpGemmMfma_f32_32x32x64_bf8_bf8<EQuad>; };
|
||||
|
||||
//WMMA cases
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_f8_bf8<TransposeC>; };
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 16, 16, 16, TransposeC, false> { using Type =WarpGemmWmma_f32_16x16x16_bf8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<fp8_t, fp8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f8_f8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<bf8_t, bf8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf8_bf8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<fp8_t, bf8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_f8_bf8<TransposeC>; };
|
||||
template<bool TransposeC> struct Dispatcher<bf8_t, fp8_t, float, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_f32_16x16x16_bf8_f8<TransposeC>; };
|
||||
|
||||
// int8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
template<> struct Dispatcher<int8_t, int8_t, int32_t, 32, 32, 16, false> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8; };
|
||||
template<> struct Dispatcher<int8_t, int8_t, int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
// WMMA cases
|
||||
template<bool TransposeC> struct WarpGemmDispatcher<ck_tile::int8_t, ck_tile::int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8<TransposeC>;};
|
||||
template<bool TransposeC> struct Dispatcher<int8_t, int8_t, int32_t, 16, 16, 16, TransposeC, false> { using Type = WarpGemmWmma_i32_16x16x16_i8_i8<TransposeC>;};
|
||||
|
||||
// clang-format on
|
||||
} // namespace warp_gemm_dispatcher
|
||||
} // namespace impl
|
||||
|
||||
template <typename AType,
|
||||
@@ -179,15 +164,16 @@ template <typename AType,
|
||||
bool SwizzleA = false,
|
||||
bool UseStructuredSparsity = false,
|
||||
WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmDispatcher = typename impl::WarpGemmDispatcher<AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
using WarpGemmDispatcher = typename impl::warp_gemm_dispatcher::Dispatcher< //
|
||||
AType,
|
||||
BType,
|
||||
AccType,
|
||||
MPerWave,
|
||||
NPerWave,
|
||||
KPerWave,
|
||||
TransposeC,
|
||||
SwizzleA,
|
||||
UseStructuredSparsity,
|
||||
AttrNumAccess>::Type;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user