diff --git a/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp index 2b3cb03b78..5901c42e55 100644 --- a/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp +++ b/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp @@ -27,8 +27,8 @@ template __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( - const typename vector_type::VectorType* const __restrict__ p_in_global, - const typename vector_type::VectorType* const __restrict__ p_wei_global, + const typename vector_type::VectorType* const __restrict__ p_in_vec_global, + const typename vector_type::VectorType* const __restrict__ p_wei_vec_global, Float* const __restrict__ p_out_global) { using scalar_t = Float; @@ -76,25 +76,25 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( ? InBlockCopyDataPerRead : WeiBlockCopyDataPerRead; - __shared__ Float p_in_block[max_align * ((in_block_size + max_align - 1) / max_align)]; - __shared__ Float p_wei_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; + __shared__ vector_t p_in_vec_block[max_align * ((in_block_size + max_align - 1) / max_align)]; + __shared__ vector_t p_wei_vec_block[max_align * ((wei_block_size + max_align - 1) / max_align)]; // threadwise tensors constexpr unsigned HiPerThread = HoPerThread + Y - 1; constexpr unsigned WiPerThread = WoPerThread + X - 1; - constexpr auto in_nchw_thread_block_desc = + constexpr auto in_nchw_vec_thread_block_desc = make_ConstantTensorDescriptor(Sequence{}, in_nchw_vec_block_desc.GetStrides()); - constexpr auto wei_kcyx_thread_block_desc = make_ConstantTensorDescriptor( + constexpr auto wei_kcyx_vec_thread_block_desc = make_ConstantTensorDescriptor( Sequence{}, wei_kcyx_vec_block_desc.GetStrides()); constexpr auto out_nkhw_thread_desc = get_convolution_output_default_4d_tensor_descriptor( - in_nchw_thread_block_desc, wei_kcyx_thread_block_desc); + in_nchw_vec_thread_block_desc, wei_kcyx_vec_thread_block_desc); // register - Float p_out_thread[out_nkhw_thread_desc.GetElementSpace()]; + scalar_t p_out_thread[out_nkhw_thread_desc.GetElementSpace()]; // divide block work constexpr unsigned NBlockWork = @@ -150,7 +150,7 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( constexpr auto blockwise_in_copy = Blockwise4dTensorCopy1