From a36d246cc01cc7a519095738754a13a2eb718200 Mon Sep 17 00:00:00 2001 From: root Date: Sat, 29 Mar 2025 21:05:21 +0000 Subject: [PATCH] [GEMM] fix MFMA configurations --- ...k_gemm_asmem_bsmem_creg_default_policy.hpp | 34 +++++++++---------- 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp index 011af064af..3894b6cd27 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -6,10 +6,10 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" -#define mfma_m32_n32_k16 1 #define mfma_m32_n32_k8 0 +#define mfma_m32_n32_k16 0 #define mfma_m16_n16_k16 0 -#define mfma_m16_n16_k32 0 +#define mfma_m16_n16_k32 1 namespace ck_tile { @@ -21,21 +21,6 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { #if mfma_m32_n32_k8 -#pragma message ("mfma m32 n32 k16") ->>>>>>> Stashed changes - if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); - } - else if constexpr(std::is_same_v && - std::is_same_v && - std::is_same_v) - { - return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); - } -#elif mfma_m32_n32_k16 #pragma message ("mfma m32 n32 k8") if constexpr(std::is_same_v && std::is_same_v && @@ -49,7 +34,20 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy { return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); } - +#elif mfma_m32_n32_k16 +#pragma message ("mfma m32 n32 k16") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); + } #elif mfma_m16_n16_k16 #pragma message("mfma m16 n16 k16") if constexpr(std::is_same_v &&