diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp index 803f6b1b60..b14720eb88 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp @@ -112,7 +112,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf "be violated"); // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0, + static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, "wrong! cannot divide work evenly among block"); constexpr index_t KBlockWork = K / KPerBlock; diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp index 43145950e3..2172751e6e 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer.hpp @@ -96,7 +96,7 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf "be violated"); // divide block work by [K, B] - static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0, + static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0, "wrong! cannot divide work evenly among block"); constexpr index_t KBlockWork = K / KPerBlock; diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index caecfce7fa..710009c72f 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -75,11 +75,11 @@ int main(int argc, char* argv[]) using namespace ck; #if 1 - constexpr index_t N = 256; - constexpr index_t C = 64; + constexpr index_t N = 512; + constexpr index_t C = 16; constexpr index_t HI = 17; constexpr index_t WI = 17; - constexpr index_t K = 256; + constexpr index_t K = 512; constexpr index_t Y = 17; constexpr index_t X = 17;