diff --git a/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp b/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp index 602702949e..d91757dc8f 100644 --- a/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp +++ b/driver/device_direct_convolution_2_nchw_kcyx_nkhw.hpp @@ -2,6 +2,7 @@ #include #include "device.hpp" #include "gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp" +#include "gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp" template void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc, @@ -57,27 +58,33 @@ void device_direct_convolution_2_nchw_kcyx_nkhw(InDesc, for(unsigned i = 0; i < nrepeat; ++i) { - float time = launch_kernel(gridwise_direct_convolution_2_nchw_kcyx_nkhw, - dim3(GridSize), - dim3(BlockSize), - static_cast(in_device_buf.GetDeviceBuffer()), - static_cast(wei_device_buf.GetDeviceBuffer()), - static_cast(out_device_buf.GetDeviceBuffer())); + float time = launch_kernel( +#if 0 + gridwise_direct_convolution_2_nchw_kcyx_nkhw +#else + gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw +#endif + , + dim3(GridSize), + dim3(BlockSize), + static_cast(in_device_buf.GetDeviceBuffer()), + static_cast(wei_device_buf.GetDeviceBuffer()), + static_cast(out_device_buf.GetDeviceBuffer())); printf("Elapsed time : %f ms\n", time); usleep(std::min(time * 1000, float(10000))); diff --git a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp index c885894165..3edd8253dd 100644 --- a/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp +++ b/driver/device_implicit_gemm_convolution_2_chwn_cyxk_khwn.hpp @@ -211,7 +211,7 @@ void device_implicit_gemm_convolution_2_chwn_cyxk_khwn(InDesc, for(unsigned i = 0; i < nrepeat; ++i) { float time = launch_kernel( -#if 1 +#if 0 gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn #else gridwise_implicit_gemm_convolution_2_chwn_cyxk_khwn_lds_double_buffer diff --git a/src/include/Array.hip.hpp b/src/include/Array.hip.hpp new file mode 100644 index 0000000000..1caab6a4c9 --- /dev/null +++ b/src/include/Array.hip.hpp @@ -0,0 +1,18 @@ +#pragma once + +template +struct Array +{ + using Type = Array; + + static constexpr unsigned nSize = NSize; + + unsigned mData[nSize]; + + template + __host__ __device__ Array(Xs... xs) : mData({static_cast(xs)...}) + { + } + + __host__ __device__ TData operator[](unsigned i) const { return mData[i]; } +}; diff --git a/src/include/ConstantTensorDescriptor.hip.hpp b/src/include/ConstantTensorDescriptor.hip.hpp index 2352b0f50c..2e5d237e81 100644 --- a/src/include/ConstantTensorDescriptor.hip.hpp +++ b/src/include/ConstantTensorDescriptor.hip.hpp @@ -65,8 +65,8 @@ __host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence struct ConstantTensorDescriptor { + using Type = ConstantTensorDescriptor; static constexpr unsigned nDim = Lengths::nDim; - using NDimConstant = Number; __host__ __device__ constexpr ConstantTensorDescriptor() { @@ -91,293 +91,70 @@ struct ConstantTensorDescriptor return Strides{}.Get(Number{}); } + // c++14 doesn't support constexpr lambdas, has to use this trick instead + struct GetElementSize_f + { + template + __host__ __device__ constexpr unsigned operator()(IDim idim) const + { + return Type{}.GetLength(idim); + } + }; + __host__ __device__ constexpr unsigned GetElementSize() const { - static_assert(nDim >= 2 && nDim <= 8, "nDim"); - - if(nDim == 2) + // c++14 doesn't support constexpr lambdas, has to use this trick instead + struct multiply { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; + __host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const + { + return a * b; + } + }; - return GetLength(I0) * GetLength(I1); - } - else if(nDim == 3) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - return GetLength(I0) * GetLength(I1) * GetLength(I2); - } - else if(nDim == 4) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3); - } - else if(nDim == 5) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4); - } - else if(nDim == 6) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) * - GetLength(I5); - } - else if(nDim == 7) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - - return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) * - GetLength(I5) * GetLength(I6); - } - else if(nDim == 8) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - - return GetLength(I0) * GetLength(I1) * GetLength(I2) * GetLength(I3) * GetLength(I4) * - GetLength(I5) * GetLength(I6) * GetLength(I7); - } - else - { - assert(false); - } + return static_const_reduce_n{}(GetElementSize_f{}, multiply{}); } + // c++14 doesn't support constexpr lambdas, has to use this trick instead + struct GetElementSpace_f + { + template + __host__ __device__ constexpr unsigned operator()(IDim idim) const + { + return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim); + } + }; + template > __host__ __device__ constexpr unsigned GetElementSpace(Align align = Align{}) const { - static_assert(nDim >= 2 && nDim <= 8, "nDim"); - - constexpr unsigned align_size = align.Get(); - - if(nDim == 2) + // c++14 doesn't support constexpr lambdas, has to use this trick instead + struct add { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; + __host__ __device__ constexpr unsigned operator()(unsigned a, unsigned b) const + { + return a + b; + } + }; - return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) + - align_size; - } - else if(nDim == 3) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) + - (GetLength(I2) - 1) * GetStride(I2) + align_size; - } - else if(nDim == 4) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) + - (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) + - align_size; - } - else if(nDim == 5) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) + - (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) + - (GetLength(I4) - 1) * GetStride(I4) + align_size; - } - else if(nDim == 6) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) + - (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) + - (GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) + - align_size; - } - else if(nDim == 7) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - - return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) + - (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) + - (GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) + - (GetLength(I6) - 1) * GetStride(I6) + align_size; - } - else if(nDim == 8) - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - - return (GetLength(I0) - 1) * GetStride(I0) + (GetLength(I1) - 1) * GetStride(I1) + - (GetLength(I2) - 1) * GetStride(I2) + (GetLength(I3) - 1) * GetStride(I3) + - (GetLength(I4) - 1) * GetStride(I4) + (GetLength(I5) - 1) * GetStride(I5) + - (GetLength(I6) - 1) * GetStride(I6) + (GetLength(I7) - 1) * GetStride(I7) + - align_size; - } + return static_const_reduce_n{}(GetElementSpace_f{}, add{}) + align.Get(); } - // this is ugly, only for 2d - __host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1) const + template + __host__ __device__ unsigned Get1dIndex(Is... is) const { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; + static_assert(sizeof...(Is) == nDim, "number of multi-index is wrong"); - static_assert(nDim == 2, "nDim is not 2"); - return i0 * GetStride(I0) + i1 * GetStride(I1); - } + const auto multi_id = Array(is...); - // this is ugly, only for 3d - __host__ __device__ unsigned Get1dIndex(unsigned i0, unsigned i1, unsigned i2) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; + unsigned id = 0; - static_assert(nDim == 3, "nDim is not 3"); - return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2); - } + static_loop_n{}([&](auto IDim) { + constexpr unsigned idim = IDim.Get(); + id += multi_id[idim] * GetStride(IDim); + }); - // this is ugly, only for 4d - __host__ __device__ unsigned - Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - static_assert(nDim == 4, "nDim is not 4"); - return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3); - } - - // this is ugly, only for 5d - __host__ __device__ unsigned - Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - - static_assert(nDim == 5, "nDim is not 5"); - return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) + - i4 * GetStride(I4); - } - - // this is ugly, only for 6d - __host__ __device__ unsigned - Get1dIndex(unsigned i0, unsigned i1, unsigned i2, unsigned i3, unsigned i4, unsigned i5) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - - static_assert(nDim == 6, "nDim is not 6"); - return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) + - i4 * GetStride(I4) + i5 * GetStride(I5); - } - - // this is ugly, only for 7d - __host__ __device__ unsigned Get1dIndex(unsigned i0, - unsigned i1, - unsigned i2, - unsigned i3, - unsigned i4, - unsigned i5, - unsigned i6) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - - static_assert(nDim == 7, "nDim is not 7"); - return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) + - i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6); - } - - // this is ugly, only for 8d - __host__ __device__ unsigned Get1dIndex(unsigned i0, - unsigned i1, - unsigned i2, - unsigned i3, - unsigned i4, - unsigned i5, - unsigned i6, - unsigned i7) const - { - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - - static_assert(nDim == 8, "nDim is not 8"); - return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3) + - i4 * GetStride(I4) + i5 * GetStride(I5) + i6 * GetStride(I6) + i7 * GetStride(I7); + return id; } __host__ __device__ constexpr auto Condense() const @@ -385,6 +162,12 @@ struct ConstantTensorDescriptor constexpr auto default_strides = calculate_default_strides(Lengths{}); return ConstantTensorDescriptor{}; } + + template + __host__ __device__ constexpr auto Vectorize(Number, Number) const + { + assert(false); // not implemented + } }; template diff --git a/src/include/Sequence.hip.hpp b/src/include/Sequence.hip.hpp new file mode 100644 index 0000000000..c8ca7a0f24 --- /dev/null +++ b/src/include/Sequence.hip.hpp @@ -0,0 +1,92 @@ +#pragma once +#include "constant_integral.hip.hpp" +#include "functional.hip.hpp" + +template +struct Sequence +{ + using Type = Sequence; + + static constexpr unsigned nDim = sizeof...(Is); + + const unsigned mData[nDim] = {Is...}; + + template + __host__ __device__ constexpr unsigned Get(Number) const + { + return mData[I]; + } + + // this is ugly, only for nDIm = 4 + template + __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence) const + { + static_assert(nDim == 4, "nDim != 4"); + + constexpr auto old_sequence = Type{}; + + constexpr unsigned NR0 = old_sequence.mData[I0]; + constexpr unsigned NR1 = old_sequence.mData[I1]; + constexpr unsigned NR2 = old_sequence.mData[I2]; + constexpr unsigned NR3 = old_sequence.mData[I3]; + + return Sequence{}; + } + + template + __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence) const + { + // don't know how to implement this + printf("Sequence::ReorderByPutOldToNew not implemented"); + assert(false); + } + + template + __host__ __device__ constexpr auto PushBack(Number) const + { + return Sequence{}; + } + + __host__ __device__ constexpr auto PopBack() const; + + template + __host__ __device__ constexpr auto Transform(F f) const + { + return Sequence{}; + } +}; + +template +__host__ __device__ constexpr auto sequence_pop_back(Sequence) +{ + static_assert(sizeof...(Is) >= 1, "empty Sequence!"); + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_sequence_op(Sequence, Sequence, F f) +{ + static_assert(Sequence::nDim == Sequence::nDim, "Dim not the same"); + + return Sequence{}; +} + +template +__host__ __device__ constexpr auto sequence_sequence_add(Sequence, Sequence) +{ + struct add + { + __host__ __device__ constexpr unsigned operator()(unsigned x, unsigned y) const + { + return x + y; + } + }; + + return sequence_sequence_op(Sequence{}, Sequence{}, add{}); +} + +template +__host__ __device__ constexpr auto Sequence::PopBack() const +{ + return sequence_pop_back(Type{}); +} diff --git a/src/include/common.hip.hpp b/src/include/common.hip.hpp index 2df008fcad..f447fce784 100644 --- a/src/include/common.hip.hpp +++ b/src/include/common.hip.hpp @@ -1,4 +1,8 @@ #pragma once +#include "constant_integral.hip.hpp" +#include "Sequence.hip.hpp" +#include "Array.hip.hpp" +#include "functional.hip.hpp" __device__ unsigned get_thread_local_1d_id() { return threadIdx.x; } @@ -91,54 +95,6 @@ struct vector_type }; #endif -template -struct integral_constant -{ - static const T value = N; - - __host__ __device__ constexpr T Get() const { return value; } -}; - -template -using Number = integral_constant; - -template -struct Sequence -{ - using Type = Sequence; - - static constexpr unsigned nDim = sizeof...(Is); - - const unsigned mData[nDim] = {Is...}; - - template - __host__ __device__ constexpr unsigned Get(Number) const - { - return mData[I]; - } - - template - __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence) const - { - constexpr auto old_sequence = Type{}; - - constexpr unsigned NR0 = old_sequence.mData[I0]; - constexpr unsigned NR1 = old_sequence.mData[I1]; - constexpr unsigned NR2 = old_sequence.mData[I2]; - constexpr unsigned NR3 = old_sequence.mData[I3]; - - return Sequence{}; - } - - template - __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence) const - { - // don't know how to implement this - printf("Sequence::ReorderByPutOldToNew not implemented"); - assert(false); - } -}; - template __host__ __device__ constexpr T max(T a, T b) { diff --git a/src/include/constant_integral.hip.hpp b/src/include/constant_integral.hip.hpp new file mode 100644 index 0000000000..70dc69d181 --- /dev/null +++ b/src/include/constant_integral.hip.hpp @@ -0,0 +1,12 @@ +#pragma once + +template +struct integral_constant +{ + static const T value = N; + + __host__ __device__ constexpr T Get() const { return value; } +}; + +template +using Number = integral_constant; diff --git a/src/include/functional.hip.hpp b/src/include/functional.hip.hpp new file mode 100644 index 0000000000..598d5c3c71 --- /dev/null +++ b/src/include/functional.hip.hpp @@ -0,0 +1,49 @@ +#pragma once +#include "constant_integral.hip.hpp" + +template +struct static_loop_n +{ + template + __host__ __device__ void operator()(F f) const + { + static_assert(NLoop > 1, "out-of-range"); + + f(Number{}); + static_loop_n{}(f); + } +}; + +template <> +struct static_loop_n<1> +{ + template + __host__ __device__ void operator()(F f) const + { + f(Number<0>{}); + } +}; + +template +struct static_const_reduce_n +{ + template + __host__ __device__ constexpr auto operator()(F f, Reduce r) const + { + static_assert(NLoop > 1, "out-of-range"); + + constexpr auto a = f(Number{}); + auto b = static_const_reduce_n{}(f, r); // cannot use constexpr here, weird + return r(a, b); + } +}; + +template <> +struct static_const_reduce_n<1> +{ + template + __host__ __device__ constexpr auto operator()(F f, Reduce) const + { + return f(Number<0>{}); + } +}; diff --git a/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp b/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp new file mode 100644 index 0000000000..cb2a8a5087 --- /dev/null +++ b/src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp @@ -0,0 +1,195 @@ +#pragma once +#include "common.hip.hpp" +#include "ConstantTensorDescriptor.hip.hpp" +#include "blockwise_4d_tensor_op.hip.hpp" +#include "blockwise_direct_convolution.hip.hpp" +#include "threadwise_4d_tensor_op.hip.hpp" +#include "threadwise_direct_convolution.hip.hpp" + +template +__global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw( + const Float* const __restrict__ p_in_global, + const Float* const __restrict__ p_wei_global, + Float* const __restrict__ p_out_global) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_global_desc = InGlobalDesc{}; + constexpr auto wei_global_desc = WeiGlobalDesc{}; + constexpr auto out_global_desc = OutGlobalDesc{}; + + constexpr unsigned Y = wei_global_desc.GetLength(I2); + constexpr unsigned X = wei_global_desc.GetLength(I3); + + constexpr unsigned HiPerBlock = HoPerBlock + Y - 1; + constexpr unsigned WiPerBlock = WoPerBlock + X - 1; + + constexpr auto in_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto wei_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // shared mem + constexpr unsigned in_block_size = in_block_desc.GetElementSpace(); + constexpr unsigned wei_block_size = wei_block_desc.GetElementSpace(); + + __shared__ Float p_in_block[in_block_size]; + __shared__ Float p_wei_block[wei_block_size]; + + // threadwise tensors + constexpr unsigned HiPerThread = HoPerThread + Y - 1; + constexpr unsigned WiPerThread = WoPerThread + X - 1; + + constexpr auto in_thread_block_desc = make_ConstantTensorDescriptor( + Sequence{}, in_block_desc.GetStrides()); + + constexpr auto wei_thread_block_desc = make_ConstantTensorDescriptor( + Sequence{}, wei_block_desc.GetStrides()); + + constexpr auto out_thread_desc = get_convolution_output_default_4d_tensor_descriptor( + in_thread_block_desc, wei_thread_block_desc); + + // register + Float p_out_thread[out_thread_desc.GetElementSpace()]; + + // divide block work + constexpr unsigned NBlockWork = (out_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; + constexpr unsigned KBlockWork = (out_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; + constexpr unsigned HBlockWork = (out_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; + constexpr unsigned WBlockWork = (out_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; + + const unsigned block_id = blockIdx.x; + + unsigned itmp = block_id; + const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); + const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); + itmp -= k_block_work_id * (HBlockWork * WBlockWork); + const unsigned h_block_work_id = itmp / WBlockWork; + const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; + + const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; + const unsigned wo_block_data_begin = w_block_work_id * WoPerBlock; + + const unsigned hi_block_data_begin = ho_block_data_begin; // minus padding + const unsigned wi_block_data_begin = wo_block_data_begin; // minus padding + + // divide thread work + constexpr unsigned NThreadWork = (NPerBlock + NPerThread - 1) / NPerThread; + constexpr unsigned KThreadWork = (KPerBlock + KPerThread - 1) / KPerThread; + constexpr unsigned HThreadWork = (HoPerBlock + HoPerThread - 1) / HoPerThread; + constexpr unsigned WThreadWork = (WoPerBlock + WoPerThread - 1) / WoPerThread; + + const unsigned thread_id = threadIdx.x; + + itmp = thread_id; + const unsigned n_thread_work_id = itmp / (KThreadWork * HThreadWork * WThreadWork); + itmp -= n_thread_work_id * (KThreadWork * HThreadWork * WThreadWork); + const unsigned k_thread_work_id = itmp / (HThreadWork * WThreadWork); + itmp -= k_thread_work_id * (HThreadWork * WThreadWork); + const unsigned h_thread_work_id = itmp / WThreadWork; + const unsigned w_thread_work_id = itmp - h_thread_work_id * WThreadWork; + + const unsigned n_thread_data_begin = n_thread_work_id * NPerThread; + const unsigned k_thread_data_begin = k_thread_work_id * KPerThread; + const unsigned ho_thread_data_begin = h_thread_work_id * HoPerThread; + const unsigned wo_thread_data_begin = w_thread_work_id * WoPerThread; + + const unsigned hi_thread_data_begin = ho_thread_data_begin; + const unsigned wi_thread_data_begin = wo_thread_data_begin; + + constexpr auto blockwise_in_copy = + Blockwise4dTensorCopy1{}; + + constexpr auto blockwise_wei_copy = + Blockwise4dTensorCopy1{}; + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_thread_desc, p_out_thread); + + for(unsigned c_block_data_begin = 0; c_block_data_begin < in_global_desc.GetLength(I1); + c_block_data_begin += CPerBlock, __syncthreads()) + { + // copy input tensor to LDS + blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin, + c_block_data_begin, + hi_block_data_begin, + wi_block_data_begin), + p_in_block); + + // copy weight tensor to LDS + blockwise_wei_copy.Run( + p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), + p_wei_block); + + __syncthreads(); + + for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread) + { +// threadwise convolution +#if 1 + threadwise_direct_convolution_2( + in_thread_block_desc, + p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data, + hi_thread_data_begin, + wi_thread_data_begin), + wei_thread_block_desc, + p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), + out_thread_desc, + p_out_thread); +#elif 0 + threadwise_direct_convolution_3( + in_thread_block_desc, + p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin, + c_thread_data, + hi_thread_data_begin, + wi_thread_data_begin), + wei_thread_block_desc, + p_wei_block + wei_block_desc.Get1dIndex(k_thread_data_begin, c_thread_data, 0, 0), + out_thread_desc, + p_out_thread); +#endif + } + } + + // copy output tensor from register to global mem + threadwise_4d_tensor_copy( + out_thread_desc, + p_out_thread, + out_global_desc, + p_out_global + out_global_desc.Get1dIndex(n_block_data_begin + n_thread_data_begin, + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin), + out_thread_desc.GetLengths()); +}