From 9de63930c07fc4b80654e53169bbd62fca00a781 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 18 Jun 2019 03:19:56 -0500 Subject: [PATCH] refactor --- ..._implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp | 5 ++--- composable_kernel/include/utility/utility.hpp | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp index 6193a1a4d3..d5ea777824 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -237,11 +237,10 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer GemmDataPerReadB); constexpr index_t in_block_space = - math::integer_divide_ceil(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align) * - max_align; + math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align); constexpr index_t wei_block_space = - math::integer_divide_ceil(wei_e_k_block_desc.GetElementSpace(), max_align) * max_align; + math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); __shared__ Float p_in_block_double[2 * in_block_space]; __shared__ Float p_wei_block_double[2 * wei_block_space]; diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp index fbf86610b1..e873881561 100644 --- a/composable_kernel/include/utility/utility.hpp +++ b/composable_kernel/include/utility/utility.hpp @@ -54,6 +54,14 @@ __host__ __device__ constexpr T integer_divide_ceil(T a, T b) return (a + b - 1) / b; } +template +__host__ __device__ constexpr T integer_least_multiple(T a, T b) +{ + static_assert(is_same{} || is_same{}, "wrong type"); + + return b * integer_divide_ceil(a, b); +} + template __host__ __device__ constexpr T max(T x) {