diff --git a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp index f1386b1d92..dc236f0473 100644 --- a/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp +++ b/composable_kernel/include/kernel_algorithm/gridwise_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp @@ -155,7 +155,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0, "GemmDataPerReadB alignment requirement is not satisfied"); -#if 1 // debug +#if 1 // input blockwise copy // slice a merged tensor, reorder and copy to a normal tensor // this copy operator already has blockwise offset built-in @@ -198,7 +198,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer Sequence{}, Number{}); -#if 1 // debug +#if 1 // operator for blockwise copy of weight into LDS // slice a tensor, and copy it into another tensor // this copy operator already have blockwise offset built-in @@ -324,10 +324,12 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer #if 1 blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); + // blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, + // True); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); #else - blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0, 0, 0}, true); - blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); + blockwise_in_copy.MoveSrcSlicingWindow(Sequence{}, True); + blockwise_wei_copy.MoveSrcSlicingWindow(Sequence{}, True); #endif __syncthreads(); @@ -348,16 +350,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer // LDS double buffer: tail { + // 
even iteration Float p_in_register_buffer[blockwise_in_copy.GetRegisterBufferSize()]; Float p_wei_register_buffer[blockwise_wei_copy.GetRegisterBufferSize()]; -// even iteration #if 1 blockwise_in_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); + // blockwise_wei_copy.MoveSlicingWindowOnSourceTensor(I0, Number{}, True); p_wei_block_on_global += EPerBlock * wei_e_k_global_desc.GetStride(I0); #else - blockwise_in_copy.MoveSrcSlicingWindow({EPerBlock, 0, 0, 0}, true); - blockwise_wei_copy.MoveSrcSlicingWindow({EPerBlock, 0}, true); + blockwise_in_copy.MoveSrcSlicingWindow(Sequence{}, True); + blockwise_wei_copy.MoveSrcSlicingWindow(Sequence{}, True); #endif __syncthreads(); @@ -431,7 +434,7 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer out_k_n1_b_n2_global_merged_desc.GetOffsetFromMultiIndex( k_thread_data_on_global, 0, b_thread_data_on_global, 0); -#if 1 // debug +#if 1 threadwise_generic_tensor_slice_copy_v1( out_n0_n1_n2_k0_k1_k2_h_w_thread_desc, p_out_thread, diff --git a/composable_kernel/include/tensor_description/tensor_coordinate.hpp b/composable_kernel/include/tensor_description/tensor_coordinate.hpp index 709beef171..25c1124755 100644 --- a/composable_kernel/include/tensor_description/tensor_coordinate.hpp +++ b/composable_kernel/include/tensor_description/tensor_coordinate.hpp @@ -125,12 +125,15 @@ struct MergedTensorCoordinate __host__ __device__ constexpr index_t GetOffset() const { return mOffset; } - // step_size should be known at compile time - template + template __host__ __device__ void - MoveOnDimension(IDim, index_t step_size, integral_constant) + MoveOnDimension(IDim idim_, T step_size, integral_constant) { - constexpr auto idim = IDim{}; + constexpr auto idim = idim_; + + // if step_size is known at compile time + static_if::value>{}( + [&](auto) { static_if{}([&](auto) { return; }); }); // update original index static_if{}([&](auto) { diff --git 
a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp index fa2466be91..1b597b804d 100644 --- a/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_generic_tensor_slice_copy.hpp @@ -446,14 +446,18 @@ struct BlockwiseGenericTensorSliceCopy_v2 mThreadwiseStore.Run(p_buffer, p_dst); } - __device__ void MoveSrcSlicingWindow(Array step_sizes, bool positive_direction) + template + __device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant) { - mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, positive_direction); + mThreadwiseLoad.MoveSrcSlicingWindow(step_sizes, + integral_constant{}); } - __device__ void MoveDstSlicingWindow(Array step_sizes, bool positive_direction) + template + __device__ void MoveDstSlicingWindow(T step_sizes, integral_constant) { - mThreadwiseStore.MoveDstSlicingWindow(step_sizes, positive_direction); + mThreadwiseStore.MoveDstSlicingWindow(step_sizes, + integral_constant{}); } private: diff --git a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp index aa75a7fe6c..48cf24068d 100644 --- a/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy.hpp @@ -216,28 +216,20 @@ struct ThreadwiseGenericTensorSliceCopy_v2 }); } - __device__ void MoveSrcSlicingWindow(Array step_sizes, bool positive_direction) + template + __device__ void MoveSrcSlicingWindow(T step_sizes, integral_constant) { - if(positive_direction) - { + static_if{}([&](auto) { mSrcSliceOrigin += step_sizes; - } - else - { - mSrcSliceOrigin -= step_sizes; - } + }).Else([&](auto) { mSrcSliceOrigin -= step_sizes; }); } - __device__ void 
MoveDstSlicingWindow(Array step_sizes, bool positive_direction) + template + __device__ void MoveDstSlicingWindow(T step_sizes, integral_constant) { - if(positive_direction) - { - mDstSliceOrigin += step_sizes; - } - else - { + static_if{}([&](auto) { mDstSliceOrigin += step_sizes; }).Else([&](auto) { mDstSliceOrigin -= step_sizes; - } + }); } // private: diff --git a/composable_kernel/include/utility/functional3.hpp b/composable_kernel/include/utility/functional3.hpp index fc5f8a6bab..73674aa039 100644 --- a/composable_kernel/include/utility/functional3.hpp +++ b/composable_kernel/include/utility/functional3.hpp @@ -8,6 +8,21 @@ namespace ck { +template +struct is_static : integral_constant +{ +}; + +template +struct is_static> : integral_constant +{ +}; + +template +struct is_static> : integral_constant +{ +}; + // RemainLengths: Sequence<...> template struct static_ford_impl diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index 4a75628952..540f81186c 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -379,7 +379,7 @@ int main(int argc, char* argv[]) #elif 0 device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw( (in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat); -#elif 0 +#elif 1 device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc, in_nchw, wei_kcyx_desc,