[GEMM] Add pragma message for different MFMA options

This commit is contained in:
YC Lin
2025-03-30 20:05:35 +00:00
parent a8027a5b2f
commit 68cd6609eb

View File

@@ -18,6 +18,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if defined(NAIVE_IMPLEMENTATION)
#pragma message ("mfma m32 n32 k8")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -31,6 +32,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
#elif defined(USING_MFMA_32x32x_8x2)
#pragma message ("mfma m32 n32 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -43,8 +45,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
#elif defined(USING_MFMA_16x16x16)
#pragma message ("mfma m16 n16 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
@@ -58,6 +60,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
#elif defined(USING_MFMA_16x16x_16x2)
#pragma message ("mfma m16 n16 k32")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)