mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
[GEMM] fix MFMA configurations
This commit is contained in:
@@ -6,10 +6,10 @@
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
|
||||
#define mfma_m32_n32_k16 1
|
||||
#define mfma_m32_n32_k8 0
|
||||
#define mfma_m32_n32_k16 0
|
||||
#define mfma_m16_n16_k16 0
|
||||
#define mfma_m16_n16_k32 0
|
||||
#define mfma_m16_n16_k32 1
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -21,21 +21,6 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
#if mfma_m32_n32_k8
|
||||
#pragma message ("mfma m32 n32 k16")
|
||||
>>>>>>> Stashed changes
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#elif mfma_m32_n32_k16
|
||||
#pragma message ("mfma m32 n32 k8")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
@@ -49,7 +34,20 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
#elif mfma_m32_n32_k16
|
||||
#pragma message ("mfma m32 n32 k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
#elif mfma_m16_n16_k16
|
||||
#pragma message("mfma m16 n16 k16")
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
|
||||
|
||||
Reference in New Issue
Block a user