diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp index 6193a1a4d3..d5ea777824 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -237,11 +237,10 @@ struct GridwiseConvolutionImplicitGemm_v4_nchw_kcyx_nkhw_lds_double_buffer GemmDataPerReadB); constexpr index_t in_block_space = - math::integer_divide_ceil(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align) * - max_align; + math::integer_least_multiple(in_e_n1_b_n2_block_desc.GetElementSpace(), max_align); constexpr index_t wei_block_space = - math::integer_divide_ceil(wei_e_k_block_desc.GetElementSpace(), max_align) * max_align; + math::integer_least_multiple(wei_e_k_block_desc.GetElementSpace(), max_align); __shared__ Float p_in_block_double[2 * in_block_space]; __shared__ Float p_wei_block_double[2 * wei_block_space]; diff --git a/composable_kernel/include/utility/utility.hpp b/composable_kernel/include/utility/utility.hpp index fbf86610b1..e873881561 100644 --- a/composable_kernel/include/utility/utility.hpp +++ b/composable_kernel/include/utility/utility.hpp @@ -54,6 +54,14 @@ __host__ __device__ constexpr T integer_divide_ceil(T a, T b) return (a + b - 1) / b; } +template +__host__ __device__ constexpr T integer_least_multiple(T a, T b) +{ + static_assert(is_same{} || is_same{}, "wrong type"); + + return b * integer_divide_ceil(a, b); +} + template __host__ __device__ constexpr T max(T x) {