[GEMM] fix MFMA configurations

This commit is contained in:
root
2025-03-29 21:05:21 +00:00
committed by Philip Maybank
parent 15e6f36f66
commit a36d246cc0

View File

@@ -6,10 +6,10 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
// Compile-time selector macros for the warp-level GEMM configuration. They are tested by
// the #if / #elif chain inside GetWarpGemmMWarpNWarp(); the first macro in that chain that
// evaluates to non-zero decides which WarpGemmMfma* type is returned.
// NOTE(review): this listing is a diff rendered without +/- markers, so mfma_m32_n32_k16 and
// mfma_m16_n16_k32 each appear twice below (old and new value). Confirm against the actual
// header which definition is live; presumably only one selector is meant to be 1 at a time.
#define mfma_m32_n32_k16 1
#define mfma_m32_n32_k8 0
#define mfma_m32_n32_k16 0
#define mfma_m16_n16_k16 0
#define mfma_m16_n16_k32 0
#define mfma_m16_n16_k32 1
namespace ck_tile {
@@ -21,21 +21,6 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
#if mfma_m32_n32_k8
#pragma message ("mfma m32 n32 k16")
>>>>>>> Stashed changes
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
#elif mfma_m32_n32_k16
#pragma message ("mfma m32 n32 k8")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
@@ -49,7 +34,20 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
#elif mfma_m32_n32_k16
#pragma message ("mfma m32 n32 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
#elif mfma_m16_n16_k16
#pragma message("mfma m16 n16 k16")
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&