diff --git a/driver/conv.cu b/driver/conv.cu index ee93414eac..898a6b76af 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -85,19 +85,19 @@ auto make_TensorDescriptor(TConstTensorDesc) } template -void host_direct_convolution(const Tensor& in, const Tensor& wei, Tensor& out) +void host_direct_convolution(const Tensor& in_nchw, const Tensor& wei_kcsr, Tensor& out) { auto f = [&](auto n, auto k, auto ho, auto wo) { double v = 0; - for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c) + for(int c = 0; c < wei_kcsr.mDesc.GetLengths()[1]; ++c) { - for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y) + for(int y = 0; y < wei_kcsr.mDesc.GetLengths()[2]; ++y) { int hi = ho + y; - for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x) + for(int x = 0; x < wei_kcsr.mDesc.GetLengths()[3]; ++x) { int wi = wo + x; - v += in(n, c, hi, wi) * wei(k, c, y, x); + v += in_nchw(n, c, hi, wi) * wei_kcsr(k, c, y, x); } } } @@ -114,19 +114,21 @@ void host_direct_convolution(const Tensor& in, const Tensor& wei, Tensor -void host_winograd_3x3_convolution(const Tensor& in, const Tensor& wei, Tensor& out) +void host_winograd_3x3_convolution(const Tensor& in_nchw, + const Tensor& wei_kcsr, + Tensor& out) { constexpr std::size_t OutTileSizeH = 2; constexpr std::size_t OutTileSizeW = 2; - std::size_t N = in.mDesc.GetLengths()[0]; - std::size_t C = in.mDesc.GetLengths()[1]; - std::size_t HI = in.mDesc.GetLengths()[2]; - std::size_t WI = in.mDesc.GetLengths()[3]; + std::size_t N = in_nchw.mDesc.GetLengths()[0]; + std::size_t C = in_nchw.mDesc.GetLengths()[1]; + std::size_t HI = in_nchw.mDesc.GetLengths()[2]; + std::size_t WI = in_nchw.mDesc.GetLengths()[3]; - std::size_t K = wei.mDesc.GetLengths()[0]; - std::size_t S = wei.mDesc.GetLengths()[2]; - std::size_t R = wei.mDesc.GetLengths()[3]; + std::size_t K = wei_kcsr.mDesc.GetLengths()[0]; + std::size_t S = wei_kcsr.mDesc.GetLengths()[2]; + std::size_t R = wei_kcsr.mDesc.GetLengths()[3]; std::size_t HO = out.mDesc.GetLengths()[2]; std::size_t WO = out.mDesc.GetLengths()[3]; @@ -150,7 +152,7 @@ void host_winograd_3x3_convolution(const Tensor& in, const Tensor& wei, Te for(int i = 0; i < InTileSizeW; ++i) { std::size_t wi = OutTileSizeW * x + i; - in_hold(n, c, y, x, j, i) = in(n, c, hi, wi); + in_hold(n, c, y, x, j, i) = in_nchw(n, c, hi, wi); } } }; @@ -194,45 +196,49 @@ void host_winograd_3x3_convolution(const Tensor& in, const Tensor& wei, Te }; auto f_wei_transform = [&](auto k, auto c) { - wei_transform(k, c, 0, 0) = wei(k, c, 0, 0); + wei_transform(k, c, 0, 0) = wei_kcsr(k, c, 0, 0); wei_transform(k, c, 0, 1) = - 0.5 * wei(k, c, 0, 0) + 0.5 * wei(k, c, 0, 1) + 0.5 * wei(k, c, 0, 2); + 0.5 * wei_kcsr(k, c, 0, 0) + 0.5 * wei_kcsr(k, c, 0, 1) + 0.5 * wei_kcsr(k, c, 0, 2); wei_transform(k, c, 0, 2) = - 0.5 * wei(k, c, 0, 0) - 0.5 * wei(k, c, 0, 1) + 0.5 * wei(k, c, 0, 2); - wei_transform(k, c, 0, 3) = wei(k, c, 0, 2); + 0.5 * wei_kcsr(k, c, 0, 0) - 0.5 * wei_kcsr(k, c, 0, 1) + 0.5 * wei_kcsr(k, c, 0, 2); + wei_transform(k, c, 0, 3) = wei_kcsr(k, c, 0, 2); wei_transform(k, c, 1, 0) = - 0.5 * wei(k, c, 0, 0) + 0.5 * wei(k, c, 1, 0) + 0.5 * wei(k, c, 2, 0); - wei_transform(k, c, 1, 1) = - 0.25 * wei(k, c, 0, 0) + 0.25 * wei(k, c, 0, 1) + 0.25 * wei(k, c, 0, 2) + - 0.25 * wei(k, c, 1, 0) + 0.25 * wei(k, c, 1, 1) + 0.25 * wei(k, c, 1, 2) + - 0.25 * wei(k, c, 2, 0) + 0.25 * wei(k, c, 2, 1) + 0.25 * wei(k, c, 2, 2); - wei_transform(k, c, 1, 2) = - 0.25 * wei(k, c, 0, 0) - 0.25 * wei(k, c, 0, 1) + 0.25 * wei(k, c, 0, 2) + - 0.25 * wei(k, c, 1, 0) - 0.25 * wei(k, c, 1, 1) + 0.25 * wei(k, c, 1, 2) + - 0.25 * wei(k, c, 2, 0) - 0.25 * wei(k, c, 2, 1) + 0.25 * wei(k, c, 2, 2); + 0.5 * wei_kcsr(k, c, 0, 0) + 0.5 * wei_kcsr(k, c, 1, 0) + 0.5 * wei_kcsr(k, c, 2, 0); + wei_transform(k, c, 1, 1) = 0.25 * wei_kcsr(k, c, 0, 0) + 0.25 * wei_kcsr(k, c, 0, 1) + + 0.25 * wei_kcsr(k, c, 0, 2) + 0.25 * wei_kcsr(k, c, 1, 0) + + 0.25 * wei_kcsr(k, c, 1, 1) + 0.25 * wei_kcsr(k, c, 1, 2) + + 0.25 * wei_kcsr(k, c, 2, 0) + 0.25 * wei_kcsr(k, c, 2, 1) + + 0.25 * wei_kcsr(k, c, 2, 2); + wei_transform(k, c, 1, 2) = 0.25 * wei_kcsr(k, c, 0, 0) - 0.25 * wei_kcsr(k, c, 0, 1) + + 0.25 * wei_kcsr(k, c, 0, 2) + 0.25 * wei_kcsr(k, c, 1, 0) - + 0.25 * wei_kcsr(k, c, 1, 1) + 0.25 * wei_kcsr(k, c, 1, 2) + + 0.25 * wei_kcsr(k, c, 2, 0) - 0.25 * wei_kcsr(k, c, 2, 1) + + 0.25 * wei_kcsr(k, c, 2, 2); wei_transform(k, c, 1, 3) = - 0.5 * wei(k, c, 0, 2) + 0.5 * wei(k, c, 1, 2) + 0.5 * wei(k, c, 2, 2); + 0.5 * wei_kcsr(k, c, 0, 2) + 0.5 * wei_kcsr(k, c, 1, 2) + 0.5 * wei_kcsr(k, c, 2, 2); wei_transform(k, c, 2, 0) = - 0.5 * wei(k, c, 0, 0) - 0.5 * wei(k, c, 1, 0) + 0.5 * wei(k, c, 2, 0); - wei_transform(k, c, 2, 1) = - 0.25 * wei(k, c, 0, 0) + 0.25 * wei(k, c, 0, 1) + 0.25 * wei(k, c, 0, 2) - - 0.25 * wei(k, c, 1, 0) - 0.25 * wei(k, c, 1, 1) - 0.25 * wei(k, c, 1, 2) + - 0.25 * wei(k, c, 2, 0) + 0.25 * wei(k, c, 2, 1) + 0.25 * wei(k, c, 2, 2); - wei_transform(k, c, 2, 2) = - 0.25 * wei(k, c, 0, 0) - 0.25 * wei(k, c, 0, 1) + 0.25 * wei(k, c, 0, 2) - - 0.25 * wei(k, c, 1, 0) + 0.25 * wei(k, c, 1, 1) - 0.25 * wei(k, c, 1, 2) + - 0.25 * wei(k, c, 2, 0) - 0.25 * wei(k, c, 2, 1) + 0.25 * wei(k, c, 2, 2); + 0.5 * wei_kcsr(k, c, 0, 0) - 0.5 * wei_kcsr(k, c, 1, 0) + 0.5 * wei_kcsr(k, c, 2, 0); + wei_transform(k, c, 2, 1) = 0.25 * wei_kcsr(k, c, 0, 0) + 0.25 * wei_kcsr(k, c, 0, 1) + + 0.25 * wei_kcsr(k, c, 0, 2) - 0.25 * wei_kcsr(k, c, 1, 0) - + 0.25 * wei_kcsr(k, c, 1, 1) - 0.25 * wei_kcsr(k, c, 1, 2) + + 0.25 * wei_kcsr(k, c, 2, 0) + 0.25 * wei_kcsr(k, c, 2, 1) + + 0.25 * wei_kcsr(k, c, 2, 2); + wei_transform(k, c, 2, 2) = 0.25 * wei_kcsr(k, c, 0, 0) - 0.25 * wei_kcsr(k, c, 0, 1) + + 0.25 * wei_kcsr(k, c, 0, 2) - 0.25 * wei_kcsr(k, c, 1, 0) + + 0.25 * wei_kcsr(k, c, 1, 1) - 0.25 * wei_kcsr(k, c, 1, 2) + + 0.25 * wei_kcsr(k, c, 2, 0) - 0.25 * wei_kcsr(k, c, 2, 1) + + 0.25 * wei_kcsr(k, c, 2, 2); wei_transform(k, c, 2, 3) = - 0.5 * wei(k, c, 0, 2) - 0.5 * wei(k, c, 1, 2) + 0.5 * wei(k, c, 2, 2); + 0.5 * wei_kcsr(k, c, 0, 2) - 0.5 * wei_kcsr(k, c, 1, 2) + 0.5 * wei_kcsr(k, c, 2, 2); - wei_transform(k, c, 3, 0) = wei(k, c, 2, 0); + wei_transform(k, c, 3, 0) = wei_kcsr(k, c, 2, 0); wei_transform(k, c, 3, 1) = - 0.5 * wei(k, c, 2, 0) + 0.5 * wei(k, c, 2, 1) + 0.5 * wei(k, c, 2, 2); + 0.5 * wei_kcsr(k, c, 2, 0) + 0.5 * wei_kcsr(k, c, 2, 1) + 0.5 * wei_kcsr(k, c, 2, 2); wei_transform(k, c, 3, 2) = - 0.5 * wei(k, c, 2, 0) - 0.5 * wei(k, c, 2, 1) + 0.5 * wei(k, c, 2, 2); - wei_transform(k, c, 3, 3) = wei(k, c, 2, 2); + 0.5 * wei_kcsr(k, c, 2, 0) - 0.5 * wei_kcsr(k, c, 2, 1) + 0.5 * wei_kcsr(k, c, 2, 2); + wei_transform(k, c, 3, 3) = wei_kcsr(k, c, 2, 2); }; auto f_out_transform = [&](auto n, auto k, auto y, auto x) { @@ -366,54 +372,66 @@ int main() constexpr unsigned R = 3; #endif - auto in_desc = make_ConstantTensorDescriptor(Sequence{}); - auto wei_desc = make_ConstantTensorDescriptor(Sequence{}); - auto out_desc = get_convolution_output_default_4d_tensor_descriptor(in_desc, wei_desc); + auto in_nchw_desc = make_ConstantTensorDescriptor(Sequence{}); + auto wei_kcsr_desc = make_ConstantTensorDescriptor(Sequence{}); + auto wei_srck_desc = make_ConstantTensorDescriptor(Sequence{}); + auto out_nkhw_desc = + get_convolution_output_default_4d_tensor_descriptor(in_nchw_desc, wei_kcsr_desc); - ostream_ConstantTensorDescriptor(in_desc, std::cout << "in_desc: "); - ostream_ConstantTensorDescriptor(wei_desc, std::cout << "wei_desc: "); - ostream_ConstantTensorDescriptor(out_desc, std::cout << "out_desc: "); + ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: "); + ostream_ConstantTensorDescriptor(wei_kcsr_desc, std::cout << "wei_kcsr_desc: "); + ostream_ConstantTensorDescriptor(wei_srck_desc, std::cout << "wei_srck_desc: "); + ostream_ConstantTensorDescriptor(out_nkhw_desc, std::cout << "out_nkhw_desc: "); - Tensor in(make_TensorDescriptor(in_desc)); - Tensor wei(make_TensorDescriptor(wei_desc)); - Tensor out_host(make_TensorDescriptor(out_desc)); - Tensor out_device(make_TensorDescriptor(out_desc)); + Tensor in_nchw(make_TensorDescriptor(in_nchw_desc)); + Tensor wei_kcsr(make_TensorDescriptor(wei_kcsr_desc)); + Tensor wei_srck(make_TensorDescriptor(wei_srck_desc)); + Tensor out_nkhw_host(make_TensorDescriptor(out_nkhw_desc)); + Tensor out_nkhw_device(make_TensorDescriptor(out_nkhw_desc)); #if 0 std::size_t num_thread = std::thread::hardware_concurrency(); - in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + in_nchw.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei_kcsr.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + wei_srck.GenerateTensorValue(GeneratorTensor_1{}, num_thread); #elif 1 std::size_t num_thread = std::thread::hardware_concurrency(); - in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + in_nchw.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei_kcsr.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + wei_srck.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); #endif for(int i = 0; i < 40; ++i) { #if 0 - device_direct_convolution_1(in_desc, in, wei_desc, wei, out_desc, out_device); + device_direct_convolution_1(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); #elif 0 - device_direct_convolution_2(in_desc, in, wei_desc, wei, out_desc, out_device); + device_direct_convolution_2( + in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); +#elif 0 + device_implicit_gemm_convolution( + in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); #elif 1 - device_implicit_gemm_convolution(in_desc, in, wei_desc, wei, out_desc, out_device); + device_implicit_gemm_convolution( + in_nchw_desc, in_nchw, wei_srck_desc, wei_srck, out_nkhw_desc, out_nkhw_device); #elif 0 - device_winograd_convolution(in_desc, in, wei_desc, wei, out_desc, out_device); + device_winograd_convolution( + in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device); #endif } #if 1 - host_winograd_3x3_convolution(in, wei, out_host); - check_error(out_host, out_device); + host_winograd_3x3_convolution(in_nchw, wei_kcsr, out_nkhw_host); + check_error(out_nkhw_host, out_nkhw_device); #elif 0 - host_direct_convolution(in, wei, out_host); - check_error(out_host, out_device); + host_direct_convolution(in_nchw, wei_kcsr, out_nkhw_host); + check_error(out_nkhw_host, out_nkhw_device); #endif #if 0 - LogRange(std::cout << "in : ", in.mData, ",") << std::endl; - LogRange(std::cout << "wei: ", wei.mData, ",") << std::endl; - LogRange(std::cout << "out_host : ", out_host.mData, ",") << std::endl; - LogRange(std::cout << "out_device: ", out_device.mData, ",") << std::endl; + LogRange(std::cout << "in_nchw : ", in_nchw.mData, ",") << std::endl; + LogRange(std::cout << "wei_kcsr: ", wei_kcsr.mData, ",") << std::endl; + LogRange(std::cout << "out_nkhw_host : ", out_nkhw_host.mData, ",") << std::endl; + LogRange(std::cout << "out_nkhw_device: ", out_nkhw_device.mData, ",") << std::endl; #endif } diff --git a/driver/device_implicit_gemm_convolution.cuh b/driver/device_implicit_gemm_convolution.cuh index 2a529e98c2..5eb9521653 100644 --- a/driver/device_implicit_gemm_convolution.cuh +++ b/driver/device_implicit_gemm_convolution.cuh @@ -1,5 +1,6 @@ #pragma once -#include "gridwise_implicit_gemm_convolution.cuh" +#include "gridwise_implicit_gemm_convolution_nchw_kcsr.cuh" +#include "gridwise_implicit_gemm_convolution_nchw_srck.cuh" template void device_implicit_gemm_convolution( @@ -25,7 +26,7 @@ void device_implicit_gemm_convolution( constexpr auto wei_desc = WeiDesc{}; constexpr auto out_desc = OutDesc{}; -#if 1 +#if 0 constexpr unsigned NPerBlock = 2; constexpr unsigned KPerBlock = 64; constexpr unsigned CPerBlock = 4; @@ -39,6 +40,20 @@ void device_implicit_gemm_convolution( constexpr unsigned WoPerThread = 4; constexpr unsigned BlockSize = 256; +#elif 1 + constexpr unsigned NPerBlock = 2; + constexpr unsigned KPerBlock = 32; + constexpr unsigned CPerBlock = 4; + constexpr unsigned HoPerBlock = 2; + constexpr unsigned WoPerBlock = 32; + + constexpr unsigned NPerThread = 2; + constexpr unsigned KPerThread = 4; + constexpr unsigned CPerThread = 2; + constexpr unsigned HoPerThread = 1; + constexpr unsigned WoPerThread = 2; + + constexpr unsigned BlockSize = 128; #endif constexpr unsigned GridSize = @@ -56,27 +71,31 @@ void device_implicit_gemm_convolution( cudaEventCreate(&start); cudaEventRecord(start, 0); - gridwise_implicit_gemm_convolution_nchw_kcsr - <<>>(InDesc{}, - static_cast(in_device_buf.GetDeviceBuffer()), - WeiDesc{}, - static_cast(wei_device_buf.GetDeviceBuffer()), - OutDesc{}, - static_cast(out_device_buf.GetDeviceBuffer())); +#if 0 + gridwise_implicit_gemm_convolution_nchw_kcsr +#elif 1 + gridwise_implicit_gemm_convolution_nchw_srck +#endif + <<>>(InDesc{}, + static_cast(in_device_buf.GetDeviceBuffer()), + WeiDesc{}, + static_cast(wei_device_buf.GetDeviceBuffer()), + OutDesc{}, + static_cast(out_device_buf.GetDeviceBuffer())); cudaEventCreate(&stop); cudaEventRecord(stop, 0); diff --git a/src/include/ConstantTensorDescriptor.cuh b/src/include/ConstantTensorDescriptor.cuh index 6030d51de6..91d5749bdd 100644 --- a/src/include/ConstantTensorDescriptor.cuh +++ b/src/include/ConstantTensorDescriptor.cuh @@ -1,6 +1,22 @@ #pragma once #include "common.cuh" +// this is ugly, only for 4d +template +__host__ __device__ constexpr auto calculate_default_strides(Sequence) +{ + return Sequence{}; +} + +// this is ugly, only for 4d +template +__host__ __device__ constexpr auto calculate_full_lengths(Sequence) +{ + static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!"); + + return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{}; +} + template struct ConstantTensorDescriptor { @@ -69,24 +85,14 @@ struct ConstantTensorDescriptor static_assert(nDim == 4, "nDim is not 4"); return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3); } + + __host__ __device__ constexpr auto Condense() const + { + constexpr auto default_strides = calculate_default_strides(Lengths{}); + return ConstantTensorDescriptor{}; + } }; -// this is ugly, only for 4d -template -__host__ __device__ constexpr auto calculate_default_strides(Sequence) -{ - return Sequence{}; -} - -// this is ugly, only for 4d -template -__host__ __device__ constexpr auto calculate_full_lengths(Sequence) -{ - static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!"); - - return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{}; -} - template __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths) { @@ -124,4 +130,4 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s) desc.GetStride(I1), desc.GetStride(I2), desc.GetStride(I3)); -} +} \ No newline at end of file diff --git a/src/include/blockwise_tensor_op.cuh b/src/include/blockwise_tensor_op.cuh index 3635235770..13e2093333 100644 --- a/src/include/blockwise_tensor_op.cuh +++ b/src/include/blockwise_tensor_op.cuh @@ -83,31 +83,31 @@ template -__device__ void -blockwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc, - Float* const __restrict__ p_src, - DstDesc, - Float* __restrict__ p_dst, - RefDesc, - Reorder, - F f) +__device__ void blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( + SrcDesc, + Float* const __restrict__ p_src, + DstDesc, + Float* __restrict__ p_dst, + SrcOpLengths, + DstFromSrcReorder, + F f) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr unsigned IT0 = Reorder{}.Get(I0); - constexpr unsigned IT1 = Reorder{}.Get(I1); - constexpr unsigned IT2 = Reorder{}.Get(I2); - constexpr unsigned IT3 = Reorder{}.Get(I3); + constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0); + constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1); + constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2); + constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3); constexpr auto src_desc = SrcDesc{}; constexpr auto dst_desc = DstDesc{}; - constexpr auto ref_desc = RefDesc{}; + constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); constexpr unsigned NLoop = ref_desc.GetElementSize() / BlockSize; @@ -133,7 +133,7 @@ blockwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc, const unsigned aindex = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]); - const unsigned bindex = dst_desc.Get1dIndex(did[IT0], did[IT1], did[IT2], did[IT3]); + const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); f(p_src[aindex], p_dst[bindex]); } @@ -164,7 +164,7 @@ blockwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc, const unsigned aindex = src_desc.Get1dIndex(did[0], did[1], did[2], did[3]); - const unsigned bindex = dst_desc.Get1dIndex(did[IT0], did[IT1], did[IT2], did[IT3]); + const unsigned bindex = dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); f(p_src[aindex], p_dst[bindex]); } @@ -183,23 +183,28 @@ template -__device__ void blockwise_4d_tensor_copy_reorder( - SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, RefDesc, Reorder) + class SrcOpLengths, + class DstFromSrcReorder> +__device__ void +blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, + Float* const __restrict__ p_src, + DstDesc, + Float* __restrict__ p_dst, + SrcOpLengths, + DstFromSrcReorder) { auto f_copy = [](const Float& src, Float& dst) { dst = src; }; - blockwise_4d_tensor_pointwise_operation_binary_reorder( - SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, Reorder{}, f_copy); + blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( + SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy); } -template +template __device__ void blockwise_4d_tensor_copy( - SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, RefDesc) + SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths) { - constexpr auto reorder = Sequence<0, 1, 2, 3>{}; + constexpr auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{}; - blockwise_4d_tensor_copy_reorder( - SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, reorder); + blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( + SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); } diff --git a/src/include/common.cuh b/src/include/common.cuh index 0939c13227..ba910853d7 100644 --- a/src/include/common.cuh +++ b/src/include/common.cuh @@ -30,6 +30,8 @@ using Number = Constant; template struct Sequence { + using Type = Sequence; + static constexpr unsigned nDim = sizeof...(Is); const unsigned mData[nDim] = {Is...}; @@ -40,44 +42,24 @@ struct Sequence return mData[I]; } - template - __host__ __device__ constexpr auto Reorder(Number, Number) const + template + __host__ __device__ constexpr auto ReorderByGetNewFromOld(Sequence) const { - constexpr unsigned IR0 = Get(Number{}); - constexpr unsigned IR1 = Get(Number{}); + constexpr auto old_sequence = Type{}; - return Sequence{}; - } + constexpr unsigned NR0 = old_sequence.mData[I0]; + constexpr unsigned NR1 = old_sequence.mData[I1]; + constexpr unsigned NR2 = old_sequence.mData[I2]; + constexpr unsigned NR3 = old_sequence.mData[I3]; - template - __host__ __device__ constexpr auto Reorder(Number, Number, Number) const - { - constexpr unsigned IR0 = Get(Number{}); - constexpr unsigned IR1 = Get(Number{}); - constexpr unsigned IR2 = Get(Number{}); - - return Sequence{}; + return Sequence{}; } template - __host__ __device__ constexpr auto Reorder(Number, Number, Number, Number) const + __host__ __device__ constexpr auto ReorderByPutOldToNew(Sequence) const { - constexpr unsigned IR0 = Get(Number{}); - constexpr unsigned IR1 = Get(Number{}); - constexpr unsigned IR2 = Get(Number{}); - constexpr unsigned IR3 = Get(Number{}); - - return Sequence{}; - } - - template - __host__ __device__ constexpr auto Reorder(Sequence) const - { - constexpr unsigned IR0 = Get(Number{}); - constexpr unsigned IR1 = Get(Number{}); - constexpr unsigned IR2 = Get(Number{}); - constexpr unsigned IR3 = Get(Number{}); - - return Sequence{}; + // don't know how to implement this + printf("Sequence::ReorderByPutOldToNew not implemented"); + assert(false); } }; diff --git a/src/include/gridwise_direct_convolution_2.cuh b/src/include/gridwise_direct_convolution_2.cuh index f61139f116..90ca39aedb 100644 --- a/src/include/gridwise_direct_convolution_2.cuh +++ b/src/include/gridwise_direct_convolution_2.cuh @@ -159,7 +159,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc, wi_block_data_begin), in_block_desc, p_in_block, - in_block_desc); + in_block_desc.GetLengths()); // copy weight tensor to LDS blockwise_4d_tensor_copy( @@ -167,7 +167,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc, p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), wei_block_desc, p_wei_block, - wei_block_desc); + wei_block_desc.GetLengths()); __syncthreads(); @@ -209,5 +209,5 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc, k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin), - out_thread_desc); + out_thread_desc.GetLengths()); } diff --git a/src/include/gridwise_implicit_gemm_convolution.cuh b/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh similarity index 80% rename from src/include/gridwise_implicit_gemm_convolution.cuh rename to src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh index 6b2bb0fdd7..08f98fce0b 100644 --- a/src/include/gridwise_implicit_gemm_convolution.cuh +++ b/src/include/gridwise_implicit_gemm_convolution_nchw_kcsr.cuh @@ -74,17 +74,39 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, const unsigned hi_block_data_begin = ho_block_data_begin; const unsigned wi_block_data_begin = wo_block_data_begin; - // tensor view of blockwise input and weight in LDS - constexpr auto wei_srck_block_desc = - make_ConstantTensorDescriptor(Sequence{}); + // tensor view of un-reorderd blockwise input and weight (imaginary) + constexpr auto in_nchw_block_desc = + make_ConstantTensorDescriptor(Sequence{}); - constexpr auto in_chwn_block_desc = - make_ConstantTensorDescriptor(Sequence{}); + constexpr auto wei_kcsr_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of reordered blockwise input and weight in LDS + constexpr auto reorder_chwn_from_nchw = Sequence<1, 2, 3, 0>{}; + constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor( + in_nchw_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_chwn_from_nchw)); + + constexpr auto reorder_srck_from_kcsr = Sequence<2, 3, 1, 0>{}; + constexpr auto wei_srck_block_desc = make_ConstantTensorDescriptor( + wei_kcsr_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_srck_from_kcsr)); // tensor view of threadwise output in register constexpr auto out_hkwn_thread_desc = make_ConstantTensorDescriptor(Sequence{}); +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc"); + print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); + + print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc"); + print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); + + print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); + } +#endif + // a series of blockwise batched GEMM // C_matrix += transpose(A_matrix) * B_matrix // A_matrix and B_matrix saved in LDS, C_matrix saved in register @@ -97,7 +119,7 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( Number{}, Number{}, - Number{}); // constexpr doesn't compile + Number{}); // constexpr doesn't compile const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( Number{}, Number{}); // constexpr doesn't compile @@ -137,11 +159,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1); c_block_data_begin += CPerBlock, __syncthreads()) { +#if 1 // input: global mem to LDS, // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N] - constexpr auto reorder_nchw2chwn = Sequence<3, 0, 1, 2>{}; - - blockwise_4d_tensor_copy_reorder( + blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( in_nchw_global_desc, p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, c_block_data_begin, @@ -149,21 +170,22 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, wi_block_data_begin), in_chwn_block_desc, p_in_block, - in_chwn_block_desc, - reorder_nchw2chwn); + in_nchw_block_desc.GetLengths(), + reorder_chwn_from_nchw); +#endif +#if 1 // weight: global mem to LDS, // convert 4d-tensor wei[K,C,S,R] to matrix wei_matrix[S*R*C,K] - constexpr auto reorder_kcsr2srck = Sequence<3, 2, 0, 1>{}; - - blockwise_4d_tensor_copy_reorder( + blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( wei_kcsr_global_desc, p_wei_global + wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0), wei_srck_block_desc, p_wei_block, - wei_srck_block_desc, - reorder_kcsr2srck); + wei_kcsr_block_desc.GetLengths(), + reorder_srck_from_kcsr); +#endif __syncthreads(); @@ -187,10 +209,10 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, const unsigned wo_thread_data_begin = matrix_c_index.row_begin / NPerThread; // output: register to global mem, - // convert matrix out_matrix[Ho*K,Wo*N] to 4d-tensor out[N,K,Ho,Wo] - constexpr auto reorder_hkwn2nkhw = Sequence<2, 1, 3, 0>{}; + // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] + constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{}; - threadwise_4d_tensor_copy_reorder( + threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( out_hkwn_thread_desc, p_out_thread, out_nkhw_global_desc, @@ -198,6 +220,6 @@ __global__ void gridwise_implicit_gemm_convolution_nchw_kcsr(InGlobalDesc, k_block_data_begin + k_thread_data_begin, ho_block_data_begin + ho_thread_data_begin, wo_block_data_begin + wo_thread_data_begin), - out_hkwn_thread_desc, - reorder_hkwn2nkhw); + out_hkwn_thread_desc.GetLengths(), + reorder_nkhw_from_hkwn); } diff --git a/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh new file mode 100644 index 0000000000..4b73e5a1af --- /dev/null +++ b/src/include/gridwise_implicit_gemm_convolution_nchw_srck.cuh @@ -0,0 +1,219 @@ +#pragma once +#include "common.cuh" +#include "ConstantTensorDescriptor.cuh" +#include "ConstantMatrixDescriptor.cuh" +#include "blockwise_tensor_op.cuh" +#include "threadwise_tensor_op.cuh" +#include "gemm.cuh" + +template +__global__ void gridwise_implicit_gemm_convolution_nchw_srck(InGlobalDesc, + Float* const __restrict__ p_in_global, + WeiGlobalDesc, + Float* const __restrict__ p_wei_global, + OutGlobalDesc, + Float* __restrict__ p_out_global) +{ + // NPerThread == NPerBlock, because the format of input in LDS [C,Hi,Wi,N] + // for GEMM trans([C,K]) * [C,Wo*N], we need a thread to do all the "N" + // if we use [C,Hi,N,Wi,N] in LDS, then NPerThread can be different from NPerBlock + constexpr unsigned NPerThread = NPerBlock; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto in_nchw_global_desc = InGlobalDesc{}; + constexpr auto wei_srck_global_desc = WeiGlobalDesc{}; + constexpr auto out_nkhw_global_desc = OutGlobalDesc{}; + + constexpr unsigned S = wei_srck_global_desc.GetLength(I0); + constexpr unsigned R = wei_srck_global_desc.GetLength(I1); + + constexpr unsigned HiPerBlock = HoPerBlock + S - 1; + constexpr unsigned WiPerBlock = WoPerBlock + R - 1; + + // divide block work: NCHW + constexpr unsigned NBlockWork = + (out_nkhw_global_desc.GetLength(I0) + NPerBlock - 1) / NPerBlock; + constexpr unsigned KBlockWork = + (out_nkhw_global_desc.GetLength(I1) + KPerBlock - 1) / KPerBlock; + constexpr unsigned HBlockWork = + (out_nkhw_global_desc.GetLength(I2) + HoPerBlock - 1) / HoPerBlock; + constexpr unsigned WBlockWork = + (out_nkhw_global_desc.GetLength(I3) + WoPerBlock - 1) / WoPerBlock; + + unsigned itmp = get_block_1d_id(); + const unsigned n_block_work_id = itmp / (KBlockWork * HBlockWork * WBlockWork); + itmp -= n_block_work_id * (KBlockWork * HBlockWork * WBlockWork); + const unsigned k_block_work_id = itmp / (HBlockWork * WBlockWork); + itmp -= k_block_work_id * (HBlockWork * WBlockWork); + const unsigned h_block_work_id = itmp / WBlockWork; + const unsigned w_block_work_id = itmp - h_block_work_id * WBlockWork; + + const unsigned n_block_data_begin = n_block_work_id * NPerBlock; + const unsigned k_block_data_begin = k_block_work_id * KPerBlock; + const unsigned ho_block_data_begin = h_block_work_id * HoPerBlock; + const unsigned wo_block_data_begin = w_block_work_id * HoPerBlock; + + const unsigned hi_block_data_begin = ho_block_data_begin; + const unsigned wi_block_data_begin = wo_block_data_begin; + + // tensor view of un-reorderd blockwise input and weight (imaginary) + constexpr auto in_nchw_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + constexpr auto wei_srck_block_desc = + make_ConstantTensorDescriptor(Sequence{}); + + // tensor view of reordered blockwise input and weight in LDS + constexpr auto reorder_chwn_from_nchw = Sequence<1, 2, 3, 0>{}; + constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor( + in_nchw_block_desc.GetLengths().ReorderByGetNewFromOld(reorder_chwn_from_nchw)); + + // tensor view of threadwise output in register + constexpr auto out_hkwn_thread_desc = + make_ConstantTensorDescriptor(Sequence{}); + +#if 0 + if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0) + { + print_ConstantTensorDescriptor(in_nchw_block_desc, "in_nchw_block_desc"); + print_ConstantTensorDescriptor(in_chwn_block_desc, "in_chwn_block_desc"); + + print_ConstantTensorDescriptor(wei_kcsr_block_desc, "wei_kcsr_block_desc"); + print_ConstantTensorDescriptor(wei_srck_block_desc, "wei_srck_block_desc"); + + print_ConstantTensorDescriptor(out_hkwn_thread_desc, "out_hkwn_thread_desc"); + } +#endif + + // a series of blockwise batched GEMM + // C_matrix += transpose(A_matrix) * B_matrix + // A_matrix and B_matrix saved in LDS, C_matrix saved in register + // A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K] + // B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N] + // C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N] + const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + const auto b_cxwn_block_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, + Number{}, + Number{}); // constexpr doesn't compile + + const auto c_kxwn_thread_mtx_desc = make_ConstantMatrixDescriptor( + Number{}, Number{}); // constexpr doesn't compile + + auto f_accum = [](auto& c, auto& ab) { c += ab; }; + + const auto blockwise_batch_gemm = + blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c{}; + + // LDS + constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace(); + constexpr unsigned wei_block_size = wei_srck_block_desc.GetElementSpace(); + + __shared__ Float p_in_block[in_block_size]; + __shared__ Float p_wei_block[wei_block_size]; + + // register + Float p_out_thread[out_hkwn_thread_desc.GetElementSpace()]; + + // set threadwise output tensor to 0 + threadwise_4d_tensor_set_zero(out_hkwn_thread_desc, p_out_thread); + + for(unsigned c_block_data_begin = 0; c_block_data_begin < in_nchw_global_desc.GetLength(I1); + c_block_data_begin += CPerBlock, __syncthreads()) + { +#if 1 + // input: global mem to LDS, + // convert 4d-tensor in[N,C,Hi,Wi] to matrix in_matrix[C,Hi*Wi*N] + blockwise_4d_tensor_copy_reorder_by_get_dst_from_src( + in_nchw_global_desc, + p_in_global + in_nchw_global_desc.Get1dIndex(n_block_data_begin, + c_block_data_begin, + hi_block_data_begin, + wi_block_data_begin), + in_chwn_block_desc, + p_in_block, + in_nchw_block_desc.GetLengths(), + reorder_chwn_from_nchw); +#endif + +#if 1 + // weight: global mem to LDS, + blockwise_4d_tensor_copy( + wei_srck_global_desc, + p_wei_global + + wei_srck_global_desc.Get1dIndex(0, 0, c_block_data_begin, k_block_data_begin), + wei_srck_block_desc, + p_wei_block, + wei_srck_block_desc.GetLengths()); +#endif + + __syncthreads(); + + // a series of batched GEMM + for(unsigned s = 0; s < S; ++s) + { + for(unsigned r = 0; r < R; ++r) + { + blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0), + p_in_block + in_chwn_block_desc.Get1dIndex(0, 0, r, 0), + p_out_thread); + } + } + } + + const auto matrix_c_index = + blockwise_batch_gemm.CalculateThreadMatrixCIndex(get_thread_local_1d_id()); + + const unsigned ho_thread_data_begin = matrix_c_index.batch_begin; + const unsigned k_thread_data_begin = matrix_c_index.col_begin; + const unsigned wo_thread_data_begin = matrix_c_index.row_begin / NPerThread; + + // output: register to global mem, + // convert out_thread[Ho,K,Wo,N] to out_global[N,K,Ho,Wo] + constexpr auto reorder_nkhw_from_hkwn = Sequence<3, 1, 0, 2>{}; + + threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( + out_hkwn_thread_desc, + p_out_thread, + out_nkhw_global_desc, + p_out_global + out_nkhw_global_desc.Get1dIndex(n_block_data_begin, + k_block_data_begin + k_thread_data_begin, + ho_block_data_begin + ho_thread_data_begin, + wo_block_data_begin + wo_thread_data_begin), + out_hkwn_thread_desc.GetLengths(), + reorder_nkhw_from_hkwn); +} diff --git a/src/include/threadwise_direct_convolution.cuh b/src/include/threadwise_direct_convolution.cuh index dfb6901fa7..47694b0b22 100644 --- a/src/include/threadwise_direct_convolution.cuh +++ b/src/include/threadwise_direct_convolution.cuh @@ -101,10 +101,10 @@ __device__ void threadwise_direct_convolution_2(InDesc, Float p_wei_reg[wei_reg_desc.GetElementSpace()]; // copy input tensor into register - threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc); + threadwise_4d_tensor_copy(in_desc, p_in, in_reg_desc, p_in_reg, in_reg_desc.GetLengths()); // copy input tensor into register - threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc); + threadwise_4d_tensor_copy(wei_desc, p_wei, wei_reg_desc, p_wei_reg, wei_reg_desc.GetLengths()); // do convolution threadwise_direct_convolution_1( @@ -159,14 +159,14 @@ __device__ void threadwise_direct_convolution_3(InDesc, p_in + in_desc.Get1dIndex(0, 0, s, 0), in_reg_desc, p_in_reg, - in_reg_desc); + in_reg_desc.GetLengths()); // read first 1x1 weight threadwise_4d_tensor_copy(wei_desc, p_wei + wei_desc.Get1dIndex(0, 0, s, 0), wei_reg_desc, p_wei_reg, - wei_reg_desc); + wei_reg_desc.GetLengths()); // do first 1x1 conv threadwise_direct_convolution_1( @@ -180,7 +180,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, p_wei + wei_desc.Get1dIndex(0, 0, s, r), wei_reg_desc, p_wei_reg, - wei_reg_desc); + wei_reg_desc.GetLengths()); // shift old input to the left threadwise_4d_tensor_shift_down(in_reg_desc, p_in_reg, I3, Number{}); @@ -192,7 +192,7 @@ __device__ void threadwise_direct_convolution_3(InDesc, in_reg_desc, p_in_reg + in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read), - in_desc_reg_new_read); + in_desc_reg_new_read.GetLengths()); // do 1x1 conv threadwise_direct_convolution_1( @@ -211,11 +211,14 @@ __device__ void threadwise_direct_convolution_3(InDesc, p_wei + wei_desc.Get1dIndex(0, 0, s, r), wei_reg_desc, p_wei_reg, - wei_reg_desc); + wei_reg_desc.GetLengths()); // read new input - threadwise_4d_tensor_copy( - in_desc, p_in + in_desc.Get1dIndex(0, 0, s, r), in_reg_desc, p_in_reg, in_reg_desc); + threadwise_4d_tensor_copy(in_desc, + p_in + in_desc.Get1dIndex(0, 0, s, r), + in_reg_desc, + p_in_reg, + in_reg_desc.GetLengths()); // do 1x1 conv threadwise_direct_convolution_1( diff --git a/src/include/threadwise_tensor_op.cuh b/src/include/threadwise_tensor_op.cuh index fcb769ddc1..d715718af9 100644 --- a/src/include/threadwise_tensor_op.cuh +++ b/src/include/threadwise_tensor_op.cuh @@ -37,29 +37,34 @@ __device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, Float* __re // TODO: in order to optimize mem access for different mem type, // need to write specialized version -template -__device__ void -threadwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc, - Float* const __restrict__ p_src, - DstDesc, - Float* __restrict__ p_dst, - RefDesc, - Reorder, - F f) +template +__device__ void threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( + SrcDesc, + Float* const __restrict__ p_src, + DstDesc, + Float* __restrict__ p_dst, + SrcOpLengths, + DstFromSrcReorder, + F f) { constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr unsigned IT0 = Reorder{}.Get(I0); - constexpr unsigned IT1 = Reorder{}.Get(I1); - constexpr unsigned IT2 = Reorder{}.Get(I2); - constexpr unsigned IT3 = Reorder{}.Get(I3); + constexpr unsigned IR0 = DstFromSrcReorder{}.Get(I0); + constexpr unsigned IR1 = DstFromSrcReorder{}.Get(I1); + constexpr unsigned IR2 = DstFromSrcReorder{}.Get(I2); + constexpr unsigned IR3 = DstFromSrcReorder{}.Get(I3); constexpr auto src_desc = SrcDesc{}; constexpr auto dst_desc = DstDesc{}; - constexpr auto ref_desc = RefDesc{}; + constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{}); for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0) { @@ -74,7 +79,7 @@ threadwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc, const unsigned did[4] = {did0, did1, did2, did3}; const unsigned bindex = - dst_desc.Get1dIndex(did[IT0], did[IT1], did[IT2], did[IT3]); + dst_desc.Get1dIndex(did[IR0], did[IR1], did[IR2], did[IR3]); f(p_src[aindex], p_dst[bindex]); } @@ -92,29 +97,29 @@ __device__ void threadwise_4d_tensor_set_zero(Desc, Float* __restrict__ p) Desc{}, p, f_set_zero); } -template -__device__ void threadwise_4d_tensor_copy_reorder( - SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, RefDesc, Reorder) +template +__device__ void +threadwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc, + Float* const __restrict__ p_src, + DstDesc, + Float* __restrict__ p_dst, + SrcOpLengths, + DstFromSrcReorder) { auto f_copy = [](const Float& src, Float& dst) { dst = src; }; - threadwise_4d_tensor_pointwise_operation_binary_reorder( - SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, Reorder{}, f_copy); + threadwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src( + SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, DstFromSrcReorder{}, f_copy); } -template +template __device__ void threadwise_4d_tensor_copy( - SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, RefDesc) + SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths) { - auto reorder = Sequence<0, 1, 2, 3>{}; + auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{}; - threadwise_4d_tensor_copy_reorder( - SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, reorder); + threadwise_4d_tensor_copy_reorder_by_get_dst_from_src( + SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder); } template