wip - continue defining warp level classes which are templated on GfxId

This commit is contained in:
Philip Maybank
2025-07-29 13:14:47 -04:00
parent 1c098fd37a
commit c62fb6c934
4 changed files with 24 additions and 2 deletions

View File

@@ -57,8 +57,8 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
using WG = WarpGemmMfmaDispatcher<ck_tile::half_t,
ck_tile::half_t,
float,
32,
32,
16,
16,
16,
true,
false,
@@ -67,6 +67,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
return make_tuple(WG{}, 4, 1);
#endif
}
#if 0
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -83,6 +84,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
wg_attr_num_access>;
return make_tuple(WG{}, 4, 1);
}
#endif
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");

View File

@@ -84,6 +84,18 @@ using WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeWmmaTransposedCDistribution<
WarpGemmAttributeWmmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
#endif
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeGenericTransposedCDistribution<
WarpGemmAttributeGenericImplF16F16F32M16N16K16<900, WGAttrCtlEnum::Default_>>>;
using WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeGenericTransposedCDistribution<
WarpGemmAttributeGenericImplF16F16F32M16N16K16<1200, WGAttrCtlEnum::Default_>>>;
#if 0
#if defined(__gfx950__)
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =

View File

@@ -130,3 +130,5 @@ struct WarpGemmAttributeGenericImplF16F16F32M16N16K16
#endif
}
};
} // namespace ck_tile

View File

@@ -34,6 +34,7 @@ struct WarpGemmWmmaDispatcher;
// clang-format off
// fp16
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
#if 0
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
@@ -42,10 +43,14 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
using Type = WarpGemmMfmaF16F16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
#endif
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmWmmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmWmmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmWmmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution; };
#if 0
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
@@ -128,6 +133,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_ti
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
#endif
// clang-format on
} // namespace impl