diff --git a/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp index 9ca83fed10..063dd98fa5 100644 --- a/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp +++ b/example/ck_tile/03_gemm/gemm_basic_mem_pipeline.cpp @@ -48,14 +48,10 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) using GemmEpilogue = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem>; - using BaseGemmPipeline = - ck_tile::BaseGemmPipelineAgBgCrMem>; + using Traits = ck_tile::TileGemmTraits; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); @@ -71,14 +67,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ck_tile::UniversalGemmPipelineProblem>; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 290af0f962..b4645db9a3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -164,7 +164,13 @@ int run_gemm_example(int argc, char* argv[]) c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); - ck_tile::reference_gemm_gpu( + ck_tile::reference_gemm_gpu( a_m_k_dev_buf, b_k_n_dev_buf, c_m_n_gpu_buf_ref, M, N, K, stride_A, stride_B, stride_C); c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp index 4c1cac002a..bfa91a689a 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp @@ -18,7 +18,7 @@ struct BlockGemmASmemBSmemCRegV1 using Policy = remove_cvref_t; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; + using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -31,7 +31,7 @@ struct BlockGemmASmemBSmemCRegV1 { static_assert(std::is_same_v && std::is_same_v && - std::is_same_v, + std::is_same_v, "wrong!"); constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}]; @@ -195,7 +195,7 @@ struct BlockGemmASmemBSmemCRegV1 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); - auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); return c_block_tensor; } diff --git a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp index 7860bb52e5..8dd1d1ec28 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp @@ -17,7 +17,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy { if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) + std::is_same_v) { #if 0 constexpr index_t kBlockSize = Problem::kBlockSize; @@ -45,7 +45,7 @@ struct BlockGemmASmemBSmemCRegV1DefaultPolicy } else if constexpr(std::is_same_v && std::is_same_v && - std::is_same_v) + std::is_same_v) { return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index d11d28acd7..a987d99bef 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { @@ -87,7 +87,7 @@ struct BaseGemmPipelineAgBgCrMem // LocalPreFillStages: 1 // LocalPreFetchStages: 0 // LocalSharedMemoryBuffer: 1 -template +template struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem { using Base = BaseGemmPipelineAgBgCrMem; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 98da1510c7..9d050be2fb 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -1,27 +1,25 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once -#include "ck_tile/core.hpp" - namespace ck_tile { template + typename ALayout_, + typename BLayout_, + typename CLayout_> struct TileGemmTraits { static constexpr bool kPadA = kPadA_; static constexpr bool kPadB = kPadB_; static constexpr bool kPadC = kPadC_; - using LayoutA = LayoutA_; - using LayoutB = LayoutB_; - using LayoutC = LayoutC_; + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; }; } // namespace ck_tile diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp index d151c3e0af..fbb3b42475 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp @@ -69,14 +69,10 @@ class TestCkTileGemmMemPipeline : public ::testing::Test using GemmEpilogue = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem>; - using BaseGemmPipeline = - ck_tile::BaseGemmPipelineAgBgCrMem>; + using Traits = ck_tile::TileGemmTraits; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< + ck_tile::GemmPipelineProblem>; const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(args.K); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); @@ -90,14 +86,8 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ck_tile::UniversalGemmPipelineProblem>;