mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-03 13:48:30 +00:00
WIP: continue defining warp-level classes that are templated on GfxId
This commit is contained in:
@@ -57,8 +57,8 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
using WG = WarpGemmMfmaDispatcher<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
32,
|
||||
32,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
true,
|
||||
false,
|
||||
@@ -67,6 +67,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
return make_tuple(WG{}, 4, 1);
|
||||
#endif
|
||||
}
|
||||
#if 0
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
@@ -83,6 +84,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy
|
||||
wg_attr_num_access>;
|
||||
return make_tuple(WG{}, 4, 1);
|
||||
}
|
||||
#endif
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");
|
||||
|
||||
@@ -84,6 +84,18 @@ using WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeWmmaTransposedCDistribution<
|
||||
WarpGemmAttributeWmmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#endif
|
||||
|
||||
// Warp GEMM alias used by the MFMA dispatcher for its
// <half_t, half_t, float, 16, 16, 16, TransposeC = true> entry.
// It wraps the generic F16F16F32 M16N16K16 attribute implementation in the
// generic transposed-C distribution attribute and hands that to WarpGemmImpl.
// NOTE(review): 900 is presumably the GfxId template argument selecting the
// target family for the generic implementation - confirm which targets it denotes.
using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeGenericTransposedCDistribution<
|
||||
WarpGemmAttributeGenericImplF16F16F32M16N16K16<900, WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
// Warp GEMM alias used by the WMMA dispatcher for its
// <half_t, half_t, float, 16, 16, 16, TransposeC = true> entry.
// Same composition as the MFMA-named alias (WarpGemmImpl over the generic
// transposed-C distribution attribute over the generic F16F16F32 M16N16K16
// implementation); only the first template argument differs (1200 instead of 900).
// NOTE(review): 1200 is presumably the GfxId template argument - confirm.
// NOTE(review): an alias with this same name, built on
// WarpGemmAttributeWmmaTransposedCDistribution, appears just before the
// preceding #endif - confirm that region is compiled out, otherwise this is
// a conflicting redeclaration.
using WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAttributeGenericTransposedCDistribution<
|
||||
WarpGemmAttributeGenericImplF16F16F32M16N16K16<1200, WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if 0
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
|
||||
using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
|
||||
|
||||
@@ -130,3 +130,5 @@ struct WarpGemmAttributeGenericImplF16F16F32M16N16K16
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -34,6 +34,7 @@ struct WarpGemmWmmaDispatcher;
|
||||
// clang-format off
|
||||
// fp16
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity[, WGAttrNumAccessEnum] (trailing arguments may be omitted)
|
||||
#if 0
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16<>; };
|
||||
@@ -42,10 +43,14 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16<WGAttrNumAccessEnum::Double>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, false, false, WGAttrNumAccessEnum::Double> {
|
||||
using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution<WGAttrNumAccessEnum::Double>; };
|
||||
#endif
|
||||
|
||||
// fp16 16x16x16 entries that remain enabled (the neighbouring table rows are inside #if 0).
// Each specialization maps its template-argument key to a concrete warp GEMM alias
// exposed as the nested `Type`. The Mfma and Wmma dispatchers get parallel rows:
// the rows ending in `false` select the plain alias, the rows ending in `true`
// select the ...TransposedCDistribution alias.
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmWmmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmWmmaF16F16F32M16N16K16; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
template<> struct WarpGemmWmmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, true> { using Type = WarpGemmWmmaF16F16F32M16N16K16TransposedCDistribution; };
|
||||
|
||||
#if 0
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution<>; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false, false, false, WGAttrNumAccessEnum::Double> {
|
||||
@@ -128,6 +133,7 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_ti
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 32, 32, 16, true> { using Type = WarpGemmMfma_i32_32x32x16_i8_i8_CTransposed; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, false> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::int8_t, ck_tile::int8_t, ck_tile::int32_t, 16, 16, 32, true> { using Type = WarpGemmMfma_i32_16x16x32_i8_i8_CTransposed; };
|
||||
#endif
|
||||
|
||||
// clang-format on
|
||||
} // namespace impl
|
||||
|
||||
Reference in New Issue
Block a user