From 3fe676171812b1138a49977e86e2a07a6469e63b Mon Sep 17 00:00:00 2001 From: zjing14 Date: Tue, 26 Sep 2023 18:40:00 -0500 Subject: [PATCH] Fixed Gemmv2r3 kpad (#938) * added kpad support into v2r3 * add generic instances * fixed comments * fixed mnk padding * Update device_batched_gemm_xdl.hpp * fixed kpad --------- Co-authored-by: Jing Zhang [ROCm/composable_kernel commit: 48ba6e8a69947c67b0b4e5021178fe3e2f2c638e] --- .../gpu/grid/gridwise_gemm_pipeline_v2.hpp | 4 ++-- .../gpu/grid/gridwise_gemm_xdlops_v2r3.hpp | 9 +++------ 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp index d3d7d5af85..25e1cebdbe 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v2.hpp @@ -9,13 +9,13 @@ namespace ck { struct GridwiseGemmPipeline_v2 { - __host__ __device__ static constexpr bool IsSupported(index_t num_loop) + __host__ __device__ static constexpr bool IsSupported(const index_t num_loop) { // TODO: improve applicability return num_loop % 2 == 0; } - __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop) + __host__ __device__ static constexpr bool CalculateHasMainLoop(const index_t num_loop) { return (num_loop / 2) > 1; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp index b6c146ae61..e941f96553 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp @@ -175,7 +175,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock; } - __host__ static auto CalculateK0(index_t K) { return math::integer_divide_floor(K, K1Value); } + __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); } // Argument struct Problem @@ -369,9 +369,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 "Invalid tuning param!"); // check gridwise gemm pipeline - const index_t K0 = problem.K / K1Value; - const auto num_k_loop = K0 / K0PerBlock; - + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; @@ -1026,8 +1024,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext } // check gridwise gemm pipeline - const index_t K0 = problem.K / K1; - const auto num_k_loop = K0 / K0PerBlock; + const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock); if(!GridwiseGemmPipe::IsSupported(num_k_loop)) {