From a60cf0d0eeafd087cbc361fad71d727e2dddb794 Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 10 Oct 2024 09:48:33 +0000 Subject: [PATCH] Use AccDataType for Output of MFMA instruction. --- example/ck_tile/03_gemm/gemm_basic.hpp | 2 +- example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp | 1 + include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr.hpp | 6 +++--- .../ops/gemm/block/block_gemm_as_bs_cr_default_policy.hpp | 4 ++-- .../ops/gemm/pipeline/block_gemm_pipeline_problem.hpp | 2 ++ 5 files changed, 9 insertions(+), 6 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.hpp b/example/ck_tile/03_gemm/gemm_basic.hpp index c2c35a572e..e3b5f86f7b 100644 --- a/example/ck_tile/03_gemm/gemm_basic.hpp +++ b/example/ck_tile/03_gemm/gemm_basic.hpp @@ -18,7 +18,7 @@ struct GemmBasicTypeConfig using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using AccDataType = float; - using CDataType = float; + using CDataType = ck_tile::half_t; // ToDo: Add more bias config to support different categories of GEMM. }; diff --git a/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp index 1edd77d4ad..85a850ffad 100644 --- a/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp +++ b/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp @@ -91,6 +91,7 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< ck_tile::UniversalGemmPipelineProblem; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -31,7 +31,7 @@ struct BlockGemmAsBsCr { static_assert(std::is_same_v && std::is_same_v && - std::is_same_v, + std::is_same_v, "wrong!"); constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; @@ -195,7 +195,7 @@ struct BlockGemmAsBsCr constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); return c_block_tensor; } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr_default_policy.hpp index 0f56a14255..74b67509e1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_as_bs_cr_default_policy.hpp @@ -17,7 +17,7 @@ struct BlockGemmAsBsCrDefaultPolicy { if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) + std::is_same_v) { #if 0 constexpr index_t kBlockSize = Problem::kBlockSize; @@ -45,7 +45,7 @@ struct BlockGemmAsBsCrDefaultPolicy } else if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) + std::is_same_v) { return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); } diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp index 9f506e6279..d34ec51365 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp @@ -42,6 +42,7 @@ struct BlockGemmPipelineProblem template ; using BDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t;