From 2cbabbba5444ca3476ebbe64857a1ee4e6388d13 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Fri, 13 Aug 2021 20:55:39 +0000 Subject: [PATCH] use int instead of index_t in kernel wrapper --- ...mplicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp index 9661f0e50c..c1208ac3cb 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -62,23 +62,39 @@ constexpr bool HasMainKBlockLoop = static_cast(CK_PARAM_HasMainKBloc constexpr bool HasDoubleTailKBlockLoop = static_cast(CK_PARAM_HasDoubleTailKBlockLoop); extern "C" __global__ void -convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N, - index_t C, - index_t Hi, - index_t Wi, - index_t K, - index_t Y, - index_t X, - index_t ConvStrideH, - index_t ConvStrideW, - index_t ConvDilationH, - index_t ConvDilationW, - index_t InLeftPadH, - index_t InLeftPadW, - index_t InRightPadH, - index_t InRightPadW, +convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(int N_, + int C_, + int Hi_, + int Wi_, + int K_, + int Y_, + int X_, + int ConvStrideH_, + int ConvStrideW_, + int ConvDilationH_, + int ConvDilationW_, + int InLeftPadH_, + int InLeftPadW_, + int InRightPadH_, + int InRightPadW_, void* p_desc_tuple) { + index_t N = static_cast(N_); + index_t C = static_cast(C_); + index_t Hi = static_cast(Hi_); + index_t Wi = static_cast(Wi_); + index_t K = static_cast(K_); + index_t Y = static_cast(Y_); + index_t X = static_cast(X_); + index_t ConvStrideH = static_cast(ConvStrideH_); + index_t ConvStrideW = static_cast(ConvStrideW_); + index_t ConvDilationH = static_cast(ConvDilationH_); + index_t ConvDilationW = static_cast(ConvDilationW_); + index_t InLeftPadH = static_cast(InLeftPadH_); + index_t InLeftPadW = static_cast(InLeftPadW_); + index_t InRightPadH = static_cast(InRightPadH_); + index_t InRightPadW = static_cast(InRightPadW_); + constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{};