diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp index 0f374d9f94..011af064af 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -6,6 +6,11 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" +#define mfma_m32_n32_k16 1 +#define mfma_m32_n32_k8 0 +#define mfma_m16_n16_k16 0 +#define mfma_m16_n16_k32 0 + namespace ck_tile { // Default policy for BlockGemmASmemBSmemCReg @@ -15,8 +20,9 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { -#if 1 -#pragma message ("mfma k16") +#if mfma_m32_n32_k8 +#pragma message ("mfma m32 n32 k16") +>>>>>>> Stashed changes if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) @@ -29,8 +35,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy { return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); } -#else -#pragma message ("mfma k8") +#elif mfma_m32_n32_k16 +#pragma message ("mfma m32 n32 k8") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) @@ -43,6 +49,35 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy { return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); } + +#elif mfma_m16_n16_k16 +#pragma message("mfma m16 n16 k16") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1); + } +#elif mfma_m16_n16_k32 +#pragma message("mfma m16 n16 k32") + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}, 4, 1); + } #endif else { diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp index 7aea4b376d..320b94c481 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp @@ -42,9 +42,9 @@ int main(int argc, char* argv[]) if(argc == 5) { verification = std::stoi(argv[1]); - M = std::stoi(argv[1]); - N = std::stoi(argv[2]); - K = std::stoi(argv[3]); + M = std::stoi(argv[2]); + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); } const ck_tile::index_t Lda = K; @@ -114,7 +114,7 @@ int main(int argc, char* argv[]) kGemmKPerBlock>; float ave_time = - ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true}, + ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true, 0, 5, 1000}, ck_tile::make_kernel( gemm_kernel{}, kGridSize,