[GEMM] use mfma k8 warp gemm

This commit is contained in:
root
2025-03-17 16:01:04 +00:00
parent 10033c1cdc
commit ff15e2da7a

View File

@@ -15,6 +15,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if 0
#pragma message ("mfma k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -27,6 +29,21 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
#else
#pragma message ("mfma k8")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
#endif
else
{
static_assert(false, "Unsupported data type configuration for GEMM warp execution.");