From 2ded318de8960d97d50ede2bb7b94ecd8b68eda3 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Wed, 18 Sep 2024 11:32:29 -0700 Subject: [PATCH] Ck tile gemm padding dim (#1516) * Support the N dimension padding * Finished the padding feature for different dimension of K [ROCm/composable_kernel commit: 694c300145799f0e8a477a1de69a0414c997bea7] --- example/ck_tile/03_gemm/gemm_basic.cpp | 10 +++---- .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 26 +++++++++++++++---- ...lock_gemm_pipeline_agmem_bgmem_creg_v1.hpp | 4 +++ .../pipeline/block_gemm_pipeline_problem.hpp | 6 ++--- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index d0b61612a0..9f790f6acb 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -179,9 +179,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_buf, std::cout << "The overall perfomance of the GEMM with " << "[" << data_type << "]" - << "batch size: " << batch_size << ". m:" << M << ",n:" << N << ", k:" << K - << "is: \n"; - std::cout << "Running time :" << ave_time << "ms, Throughput" << gb_per_sec << "GB/s \n" + << "batch size: " << batch_size << ". m:" << M << ", n:" << N << ", k:" << K + << " is: \n"; + std::cout << "Running time: " << ave_time << "ms, Throughput " << gb_per_sec << "GB/s \n" << std::flush; return ave_time; @@ -235,7 +235,7 @@ int main(int argc, char* argv[]) // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. constexpr bool kPadA = true; constexpr bool kPadB = true; - constexpr bool kPadC = false; + constexpr bool kPadC = true; // This part comes from the Codegen constexpr ck_tile::index_t M_Tile = 128; @@ -348,7 +348,7 @@ int main(int argc, char* argv[]) pass_gpu = ck_tile::check_err(c_host_dev, c_host_gpu_ref); - std::cout << "The GPU veification result is:" << (pass_gpu ? "correct" : "fail") + std::cout << "The GPU veification result is: " << (pass_gpu ? "correct" : "fail") << std::flush; } diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index 338adfd3cf..e24d7f9ea0 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -123,14 +123,26 @@ struct GemmKernel } }(); - auto ABlockWindow = make_tile_window( + auto a_pad_view = pad_tensor_view( a_tensor_view, make_tuple(number{}, number{}), + sequence < 0, + GemmPipeline::kPadA ? 1 : 0 > {}); + + auto ABlockWindow = make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), {i_m, 0}); - auto BBlockWindow = make_tile_window( + auto b_pad_view = pad_tensor_view( b_tensor_view, make_tuple(number{}, number{}), + sequence < 0, + GemmPipeline::kPadB ? 1 : 0 > {}); + + auto BBlockWindow = make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), {i_n, 0}); // allocate LDS @@ -163,12 +175,16 @@ struct GemmKernel } }(); - auto CBlockWindow = make_tile_window( + auto c_pad_view = pad_tensor_view( c_tensor_view, make_tuple(number{}, number{}), + sequence < 0, + GemmPipeline::kPadC ? 1 : 0 > {}); + auto CBlockWindow_pad = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), {i_m, i_n}); - // epilogue. - EpiloguePipeline{}(CBlockWindow, acc); + EpiloguePipeline{}(CBlockWindow_pad, acc); } }; diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp index 0557143bc8..bec8a204cc 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -29,6 +29,10 @@ struct BlockGemmPipelineAGmemBGmemCRegV1 static constexpr index_t AlignmentB = Problem::AlignmentB; static constexpr index_t AlignmentC = Problem::AlignmentC; + static constexpr bool kPadA = Problem::kPadA; + static constexpr bool kPadB = Problem::kPadB; + static constexpr bool kPadC = Problem::kPadC; + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { return ck_tile::integer_divide_ceil( diff --git a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp index acb94f8a68..8dfba08ad7 100644 --- a/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/block_gemm_pipeline_problem.hpp @@ -28,9 +28,9 @@ struct BlockGemmPipelineProblem static constexpr bool kPadB = kPadB_; static constexpr bool kPadC = kPadC_; - static constexpr index_t AlignmentA = kPadA ? VectorLoadSize / sizeof(ADataType) : 1; - static constexpr index_t AlignmentB = kPadB ? VectorLoadSize / sizeof(BDataType) : 1; - static constexpr index_t AlignmentC = kPadC ? VectorLoadSize / sizeof(CDataType) : 1; + static constexpr index_t AlignmentA = kPadA ? 1 : VectorLoadSize / sizeof(ADataType); + static constexpr index_t AlignmentB = kPadB ? 1 : VectorLoadSize / sizeof(BDataType); + static constexpr index_t AlignmentC = kPadC ? 1 : VectorLoadSize / sizeof(CDataType); }; } // namespace ck_tile