From dfa0213942000be269a1b05058f26893e7e9e56f Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 30 Oct 2018 11:12:21 -0500 Subject: [PATCH] convolution: init cuda run --- driver/conv.cu | 78 +++-- src/include/device_tensor.cuh | 38 +-- src/include/direct_convolution.cuh | 460 ++++++++++++++++++++++++++++- 3 files changed, 524 insertions(+), 52 deletions(-) diff --git a/driver/conv.cu b/driver/conv.cu index 40851c9dc8..797342d2f7 100644 --- a/driver/conv.cu +++ b/driver/conv.cu @@ -1,9 +1,29 @@ #include +#include +#include #include "nvToolsExt.h" #include "tensor.hpp" #include "device_tensor.cuh" #include "direct_convolution.cuh" +template +struct Generator +{ + T value = 0; + + template + T operator()(Is... is) + { +#if 0 + return value; +#else + std::initializer_list ls = {static_cast(is)...}; + return std::accumulate(ls.begin(), ls.end(), std::size_t(0)); +#endif + } + +}; + template void host_convolution(const Tensor& in, const Tensor& wei, @@ -37,24 +57,39 @@ void host_convolution(const Tensor& in, } template -void device_convolution(const Tensor& in, const Tensor& wei, Tensor& out) +void device_convolution(Tensor& in, Tensor& wei, Tensor& out) { - DeviceTensorDescriptor in_desc_device(in.mDesc); - DeviceTensorDescriptor wei_desc_device(wei.mDesc); - DeviceTensorDescriptor out_desc_device(out.mDesc); + DeviceTensorDescriptor<4> in_desc_device(in.mDesc); + DeviceTensorDescriptor<4> wei_desc_device(wei.mDesc); + DeviceTensorDescriptor<4> out_desc_device(out.mDesc); + + printf("__func__: in_desc_device: %u %u %u %u\n", + in_desc_device.GetLength(0), + in_desc_device.GetLength(1), + in_desc_device.GetLength(2), + in_desc_device.GetLength(3)); std::size_t data_sz = sizeof(T); DeviceMem in_device_buf(data_sz * in.mDesc.GetElementSpace()); DeviceMem wei_device_buf(data_sz * wei.mDesc.GetElementSpace()); DeviceMem out_device_buf(data_sz * out.mDesc.GetElementSpace()); + int num_thread = std::thread::hardware_concurrency(); + +#if 1 + 
in.GenerateTensorValue(Generator{1}, num_thread); + wei.GenerateTensorValue(Generator{1}, num_thread); +#endif + out.GenerateTensorValue(Generator{0}, num_thread); + in_device_buf.ToDevice(in.mData.data()); wei_device_buf.ToDevice(wei.mData.data()); + out_device_buf.ToDevice(out.mData.data()); - dim3 block_dim(256, 1, 1); + dim3 block_dim(64, 1, 1); dim3 grid_dim(1, 1, 1); - direct_convolution + gridwise_convolution <<>>(in_desc_device, static_cast(in_device_buf.GetDeviceBuffer()), wei_desc_device, @@ -65,23 +100,20 @@ void device_convolution(const Tensor& in, const Tensor& wei, Tensor& ou out_device_buf.FromDevice(out.mData.data()); } -template -struct Generator -{ - - template - T operator()(Is... is) - { - return 1; - } -}; - int main() { #if 0 - Tensor in({3, 16, 128, 128}); + Tensor in({3, 16, 130, 130}); Tensor wei({4, 16, 3, 3}); - Tensor out_host({3, 4, 126, 126}); + Tensor out_host({3, 4, 128, 128}); +#elif 0 + Tensor in({1, 1, 130, 130}); + Tensor wei({1, 1, 3, 3}); + Tensor out_host({1, 1, 128, 128}); +#elif 1 + Tensor in({1, 1, 18, 18}); + Tensor wei({1, 1, 3, 3}); + Tensor out_host({1, 1, 16, 16}); #else Tensor in({1, 1, 4, 4}); Tensor wei({1, 1, 3, 3}); @@ -93,16 +125,16 @@ int main() std::cout << __func__ << ": num_thread " << num_thread << std::endl; - in.GenerateTensorValue(Generator{}, num_thread); - wei.GenerateTensorValue(Generator{}, num_thread); + in.GenerateTensorValue(Generator{1}, num_thread); + wei.GenerateTensorValue(Generator{1}, num_thread); - host_convolution(in, wei, out_host, num_thread); + //host_convolution(in, wei, out_host, num_thread); device_convolution(in, wei, out_device); std::cout << __func__ << ": done" << std::endl; LogRange(std::cout, in.mData, ",") << std::endl; LogRange(std::cout, wei.mData, ",") << std::endl; - LogRange(std::cout, out_host.mData, ",") << std::endl; + //LogRange(std::cout, out_host.mData, ",") << std::endl; LogRange(std::cout, out_device.mData, ",") << std::endl; } diff --git 
a/src/include/device_tensor.cuh b/src/include/device_tensor.cuh index 56338452f6..de7cad8149 100644 --- a/src/include/device_tensor.cuh +++ b/src/include/device_tensor.cuh @@ -1,37 +1,31 @@ #pragma once +#include #include "helper_cuda.h" #include "tensor.hpp" +template struct DeviceTensorDescriptor { - DeviceTensorDescriptor() = delete; + __host__ __device__ DeviceTensorDescriptor() = default; __host__ DeviceTensorDescriptor(const TensorDescriptor& host_desc) - : mDataType(host_desc.GetDataType()), mDim(host_desc.GetDimension()) { - std::size_t data_sz = host_desc.GetDataType() == DataType_t::Float ? 4 : 2; - - checkCudaErrors(cudaMalloc(&mpLengths, data_sz * mDim)); - checkCudaErrors(cudaMalloc(&mpStrides, data_sz * mDim)); - - checkCudaErrors(cudaMemcpy( - mpLengths, host_desc.GetLengths().data(), data_sz * mDim, cudaMemcpyHostToDevice)); - checkCudaErrors(cudaMemcpy( - mpStrides, host_desc.GetStrides().data(), data_sz * mDim, cudaMemcpyHostToDevice)); + assert(NDim == host_desc.GetDimension()); + std::copy(host_desc.GetLengths().begin(), host_desc.GetLengths().end(), mpLengths); + std::copy(host_desc.GetStrides().begin(), host_desc.GetStrides().end(), mpStrides); } - __host__ ~DeviceTensorDescriptor() + __host__ __device__ unsigned GetLength(unsigned i) const { return mpLengths[i]; } + + __host__ __device__ unsigned long GetStride(unsigned i) const { return mpStrides[i]; } + + // this is ugly + __host__ __device__ unsigned long + Get1dIndex(unsigned n, unsigned c, unsigned h, unsigned w) const { -#if 0 - if(mpLengths != nullptr) - checkCudaErrors(cudaFree(mpLengths)); - if(mpStrides != nullptr) - checkCudaErrors(cudaFree(mpStrides)); -#endif + return n * mpStrides[0] + c * mpStrides[1] + h * mpStrides[2] + w * mpStrides[3]; } - DataType_t mDataType; - unsigned long mDim; - unsigned long* mpLengths = nullptr; - unsigned long* mpStrides = nullptr; + unsigned mpLengths[NDim]; + unsigned long mpStrides[NDim]; }; diff --git a/src/include/direct_convolution.cuh 
b/src/include/direct_convolution.cuh index 8315dcbb55..067e8c866a 100644 --- a/src/include/direct_convolution.cuh +++ b/src/include/direct_convolution.cuh @@ -1,12 +1,458 @@ #pragma once #include "device_tensor.cuh" -template -__global__ void direct_convolution(DeviceTensorDescriptor in_desc, - TFloat* const p_in, - DeviceTensorDescriptor wei_desc, - TFloat* const p_wei, - DeviceTensorDescriptor out_desc, - TFloat* p_out) +template +__device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc, + TFloat* const __restrict__ p_src, + const DeviceTensorDescriptor<4>& dst_desc, + TFloat* __restrict__ p_dst, + F f) { +#if 1 + if(threadIdx.x < 100) + { + printf("====== blockwise_4d_tensor_op: \t" + "threadIdx.x %u, p_src[threadIdx.x] %f, p_dst[threadIdx.x] %f\n", + threadIdx.x, p_src[threadIdx.x], p_dst[threadIdx.x]); + } +#endif + + constexpr unsigned NWorkStride3 = 1; + constexpr unsigned NWorkStride2 = NWorkLen3 * NWorkStride3; + constexpr unsigned NWorkStride1 = NWorkLen2 * NWorkStride2; + constexpr unsigned NWorkStride0 = NWorkLen1 * NWorkStride1; + + unsigned itmp = + threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.y * blockDim.x); + + const unsigned did0_begin = itmp / NWorkStride0; + + itmp -= did0_begin * NWorkStride0; + + const unsigned did1_begin = itmp / NWorkStride1; + + itmp -= did1_begin * NWorkStride1; + + const unsigned did2_begin = itmp / NWorkStride2; + + itmp -= did2_begin * NWorkStride2; + + const unsigned did3_begin = itmp / NWorkStride3; + + for(unsigned did0 = did0_begin; did0 < src_desc.GetLength(0); did0 += NWorkLen0) + { + for(unsigned did1 = did1_begin; did1 < src_desc.GetLength(1); did1 += NWorkLen1) + { + for(unsigned did2 = did2_begin; did2 < src_desc.GetLength(2); did2 += NWorkLen2) + { + for(unsigned did3 = did3_begin; did3 < src_desc.GetLength(3); did3 += NWorkLen3) + { + const unsigned sindex = + src_desc.GetStride(0) * did0 + src_desc.GetStride(1) * did1 + + src_desc.GetStride(2) * did2 + 
src_desc.GetStride(3) * did3; + + const unsigned dindex = + dst_desc.GetStride(0) * did0 + dst_desc.GetStride(1) * did1 + + dst_desc.GetStride(2) * did2 + dst_desc.GetStride(3) * did3; + + f(p_dst[dindex], p_src[sindex]); + +#if 1 + printf("thread id %u, dindex %u, p_dst[dindex] %f, sindex %u, p_src[sindex] %f\n", + threadIdx.x, dindex, p_dst[dindex], sindex, p_src[sindex]); +#endif + + } + } + } + } +} + +template +__device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc, + TFloat* const __restrict__ p_src, + const DeviceTensorDescriptor<4>& dst_desc, + TFloat* __restrict__ p_dst, + F f) +{ + for(unsigned did0 = 0; did0 < src_desc.GetLength(0); ++did0) + { + for(unsigned did1 = 0; did1 < src_desc.GetLength(1); ++did1) + { + for(unsigned did2 = 0; did2 < src_desc.GetLength(2); ++did2) + { + for(unsigned did3 = 0; did3 < src_desc.GetLength(3); ++did3) + { + const unsigned sindex = + src_desc.GetStride(0) * did0 + src_desc.GetStride(1) * did1 + + src_desc.GetStride(2) * did2 + src_desc.GetStride(3) * did3; + + const unsigned dindex = + dst_desc.GetStride(0) * did0 + dst_desc.GetStride(1) * did1 + + dst_desc.GetStride(2) * did2 + dst_desc.GetStride(3) * did3; + + f(p_dst[dindex], p_src[sindex]); + } + } + } + } +} + +template +__device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& in_desc, + TFloat* const __restrict__ p_in, + const DeviceTensorDescriptor<4>& wei_desc, + TFloat* const __restrict__ p_wei, + const DeviceTensorDescriptor<4>& out_desc, + TFloat* __restrict__ p_out) +{ + for(unsigned n = 0; n < out_desc.GetLength(0); ++n) + { + for(unsigned k = 0; k < out_desc.GetLength(1); ++k) + { + for(unsigned ho = 0; ho < out_desc.GetLength(2); ++ho) + { + for(unsigned wo = 0; wo < out_desc.GetLength(3); ++wo) + { + for(unsigned c = 0; c < wei_desc.GetLength(1); ++c) + { + for(unsigned s = 0; s < wei_desc.GetLength(2); ++s) + { + for(unsigned r = 0; r < wei_desc.GetLength(3); ++r) + { + const unsigned hi = ho + s; + 
const unsigned wi = wo + r; + + const unsigned in_index = + in_desc.GetStride(0) * n + in_desc.GetStride(1) * c + + in_desc.GetStride(2) * hi + in_desc.GetStride(3) * wi; + + const unsigned wei_index = + wei_desc.GetStride(0) * k + wei_desc.GetStride(1) * c + + wei_desc.GetStride(2) * s + wei_desc.GetStride(3) * r; + + const unsigned out_index = + out_desc.GetStride(0) * n + out_desc.GetStride(1) * k + + out_desc.GetStride(2) * ho + out_desc.GetStride(3) * wo; + + p_out[out_index] += p_wei[wei_index] * p_in[in_index]; + +#if 1 + if(threadIdx.x == 0 ) + { + printf("====== 5: \t" + "out_index %u, p_out[out_index] %f, \t" + "wei_index %u, p_wei[wei_index] %f, \t" + "in_index %u, p_in[in_index] %f\n", + out_index, p_out[out_index], + wei_index, p_wei[wei_index], + in_index, p_in[in_index]); + } +#endif + } + } + } + } + } + } + } + + +} + +template +__device__ void blockwise_convolution(const DeviceTensorDescriptor<4>& in_desc, + TFloat* const __restrict__ p_in, + const DeviceTensorDescriptor<4>& wei_desc, + TFloat* const __restrict__ p_wei, + const DeviceTensorDescriptor<4>& out_desc, + TFloat* __restrict__ p_out) +{ + // for now, one thread do 1 N and 1 K + DeviceTensorDescriptor<4> wei_thread_desc; + wei_thread_desc.mpLengths[0] = 1; + wei_thread_desc.mpLengths[1] = CPerBlockLoop; + wei_thread_desc.mpLengths[2] = S; + wei_thread_desc.mpLengths[3] = R; + wei_thread_desc.mpStrides[3] = 1; + wei_thread_desc.mpStrides[2] = wei_thread_desc.GetLength(3) * wei_thread_desc.GetStride(3); + wei_thread_desc.mpStrides[1] = wei_thread_desc.GetLength(2) * wei_thread_desc.GetStride(2); + wei_thread_desc.mpStrides[0] = wei_thread_desc.GetLength(1) * wei_thread_desc.GetStride(1); + + DeviceTensorDescriptor<4> out_thread_desc; + out_thread_desc.mpLengths[0] = 1; + out_thread_desc.mpLengths[1] = 1; + out_thread_desc.mpLengths[2] = OutTileSizeH; + out_thread_desc.mpLengths[3] = OutTileSizeW; + out_thread_desc.mpStrides[3] = 1; + out_thread_desc.mpStrides[2] = 
out_thread_desc.GetLength(3) * out_thread_desc.GetStride(3); + out_thread_desc.mpStrides[1] = out_thread_desc.GetLength(2) * out_thread_desc.GetStride(2); + out_thread_desc.mpStrides[0] = out_thread_desc.GetLength(1) * out_thread_desc.GetStride(1); + + DeviceTensorDescriptor<4> in_thread_desc; + in_thread_desc.mpLengths[0] = 1; + in_thread_desc.mpLengths[1] = CPerBlockLoop; + in_thread_desc.mpLengths[2] = OutTileSizeH + S - 1; + in_thread_desc.mpLengths[3] = OutTileSizeW + R - 1; + in_thread_desc.mpStrides[3] = 1; + in_thread_desc.mpStrides[2] = in_thread_desc.GetLength(3) * in_thread_desc.GetStride(3); + in_thread_desc.mpStrides[1] = in_thread_desc.GetLength(2) * in_thread_desc.GetStride(2); + in_thread_desc.mpStrides[0] = in_thread_desc.GetLength(1) * in_thread_desc.GetStride(1); + + const unsigned thread_sz = blockDim.x * blockDim.y * blockDim.z; + + const unsigned thread_id = + threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.y * blockDim.x); + + for(unsigned thread_work_id = thread_id; + thread_work_id < NPerBlock * KPerBlock * YPerBlock * XPerBlock; + thread_work_id += thread_sz) + { + unsigned itmp = thread_work_id; + unsigned n_thread_work_id = itmp / (KPerBlock * YPerBlock * XPerBlock); + itmp -= n_thread_work_id * (KPerBlock * YPerBlock * XPerBlock); + unsigned k_thread_work_id = itmp / (YPerBlock * XPerBlock); + itmp -= k_thread_work_id * (YPerBlock * XPerBlock); + unsigned y_thread_work_id = itmp / XPerBlock; + unsigned x_thread_work_id = itmp - y_thread_work_id * XPerBlock; + + unsigned n_thread_work_begin = n_thread_work_id * 1; + unsigned k_thread_work_begin = k_thread_work_id * 1; + unsigned ho_thread_work_begin = y_thread_work_id * OutTileSizeH; + unsigned wo_thread_work_begin = x_thread_work_id * OutTileSizeW; + + unsigned hi_thread_work_begin = ho_thread_work_begin; // minus padding + unsigned wi_thread_work_begin = wo_thread_work_begin; // minus padding + + TFloat p_in_thread[1 * CPerBlockLoop * InTileSizeH * InTileSizeW]; + 
TFloat p_wei_thread[1 * CPerBlockLoop * S * R]; + TFloat p_out_thread[1 * 1 * OutTileSizeH * OutTileSizeW]; + + auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; }; + + // copy input tensor into register + threadwise_4d_tensor_op( + in_desc, + p_in + in_desc.Get1dIndex( + n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin), + in_thread_desc, + p_in_thread, + f_copy); + + // copy weight tensor into register + threadwise_4d_tensor_op( + wei_desc, + p_wei + wei_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0), + wei_thread_desc, + p_wei_thread, + f_copy); + + // copy output tensor into register + threadwise_4d_tensor_op( + out_desc, + p_out + out_desc.Get1dIndex(n_thread_work_begin, + k_thread_work_begin, + ho_thread_work_begin, + wo_thread_work_begin), + out_thread_desc, + p_out_thread, + f_copy); + + // threadwise convolution + threadwise_direct_convolution(in_thread_desc, + p_in_thread, + wei_thread_desc, + p_wei_thread, + out_thread_desc, + p_out_thread); + + // accumulate output tensor into device mem + threadwise_4d_tensor_op( + out_thread_desc, + p_out_thread, + out_desc, + p_out + out_desc.Get1dIndex(n_thread_work_begin, + k_thread_work_begin, + ho_thread_work_begin, + wo_thread_work_begin), + f_copy); + } +} + +template +__global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc, + TFloat* const __restrict__ p_in, + const DeviceTensorDescriptor<4> wei_desc, + TFloat* const __restrict__ p_wei, + const DeviceTensorDescriptor<4> out_desc, + TFloat* __restrict__ p_out) +{ +#if 1 + if(threadIdx.x < 100) + { + printf("====== 0: \t" + "threadIdx.x %u, p_in[threadIdx.x] %f, p_wei[threadIdx.x] %f, p_out[threadIdx.x] %f\n", + threadIdx.x, p_in[threadIdx.x], p_wei[threadIdx.x], p_out[threadIdx.x]); + } +#endif + + const unsigned NBlockWork = (in_desc.GetLength(0) + NPerBlock - 1) / NPerBlock; + const unsigned YBlockWork = (in_desc.GetLength(2) + YPerBlock - 1) / YPerBlock; + const unsigned XBlockWork = 
(in_desc.GetLength(3) + XPerBlock - 1) / XPerBlock; + + const unsigned KBlockWork = (wei_desc.GetLength(0) + KPerBlock - 1) / KPerBlock; + + const unsigned block_id = + blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * (gridDim.y * gridDim.x); + + // this is ugly + DeviceTensorDescriptor<4> wei_block_desc; + wei_block_desc.mpLengths[0] = KPerBlock; + wei_block_desc.mpLengths[1] = CPerBlockLoop; + wei_block_desc.mpLengths[2] = S; + wei_block_desc.mpLengths[3] = R; + wei_block_desc.mpStrides[3] = 1; + wei_block_desc.mpStrides[2] = wei_block_desc.GetLength(3) * wei_block_desc.GetStride(3); + wei_block_desc.mpStrides[1] = wei_block_desc.GetLength(2) * wei_block_desc.GetStride(2); + wei_block_desc.mpStrides[0] = wei_block_desc.GetLength(1) * wei_block_desc.GetStride(1); + + DeviceTensorDescriptor<4> out_block_desc; + out_block_desc.mpLengths[0] = NPerBlock; + out_block_desc.mpLengths[1] = KPerBlock; + out_block_desc.mpLengths[2] = YPerBlock * OutTileSizeH; + out_block_desc.mpLengths[3] = XPerBlock * OutTileSizeW; + out_block_desc.mpStrides[3] = 1; + out_block_desc.mpStrides[2] = out_block_desc.GetLength(3) * out_block_desc.GetStride(3); + out_block_desc.mpStrides[1] = out_block_desc.GetLength(2) * out_block_desc.GetStride(2); + out_block_desc.mpStrides[0] = out_block_desc.GetLength(1) * out_block_desc.GetStride(1); + + DeviceTensorDescriptor<4> in_block_desc; + in_block_desc.mpLengths[0] = NPerBlock; + in_block_desc.mpLengths[1] = CPerBlockLoop; + in_block_desc.mpLengths[2] = YPerBlock * OutTileSizeH + S - 1; + in_block_desc.mpLengths[3] = XPerBlock * OutTileSizeW + R - 1; + in_block_desc.mpStrides[3] = 1; + in_block_desc.mpStrides[2] = in_block_desc.GetLength(3) * in_block_desc.GetStride(3); + in_block_desc.mpStrides[1] = in_block_desc.GetLength(2) * in_block_desc.GetStride(2); + in_block_desc.mpStrides[0] = in_block_desc.GetLength(1) * in_block_desc.GetStride(1); + + __shared__ TFloat p_in_block[NPerBlock * CPerBlockLoop * (YPerBlock * OutTileSizeH + S - 1) * (XPerBlock * OutTileSizeW + R - 1)]; + __shared__ TFloat 
p_wei_block[KPerBlock * CPerBlockLoop * S * + R]; + __shared__ TFloat p_out_block[NPerBlock * KPerBlock * (YPerBlock * OutTileSizeH) * + (XPerBlock * OutTileSizeW)]; + + unsigned itmp = block_id; + unsigned n_block_work_id = itmp / (KBlockWork * YBlockWork * XBlockWork); + itmp -= n_block_work_id * (KBlockWork * YBlockWork * XBlockWork); + unsigned k_block_work_id = itmp / (YBlockWork * XBlockWork); + itmp -= k_block_work_id * (YBlockWork * XBlockWork); + unsigned y_block_work_id = itmp / XBlockWork; + unsigned x_block_work_id = itmp - y_block_work_id * XBlockWork; + + unsigned n_block_work_begin = n_block_work_id * NPerBlock; + unsigned k_block_work_begin = k_block_work_id * KPerBlock; + unsigned y_block_work_begin = y_block_work_id * YPerBlock; + unsigned x_block_work_begin = x_block_work_id * XPerBlock; + + unsigned ho_block_work_begin = y_block_work_begin * OutTileSizeH; + unsigned wo_block_work_begin = x_block_work_begin * OutTileSizeW; + + unsigned hi_block_work_begin = ho_block_work_begin; // minus padding + unsigned wi_block_work_begin = wo_block_work_begin; // minus padding + + if(threadIdx.x == 0) + printf("====== 1:\n"); + + for(unsigned c_block_work_begin = 0; c_block_work_begin < in_desc.GetLength(1); + c_block_work_begin += CPerBlockLoop) + { + auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; }; + + // copy input tensor to LDS + blockwise_4d_tensor_op( + in_desc, + p_in + in_desc.Get1dIndex(n_block_work_begin, + c_block_work_begin, + hi_block_work_begin, + wi_block_work_begin), + in_block_desc, + p_in_block, + f_copy); + + // copy weight tensor to LDS + blockwise_4d_tensor_op( + wei_desc, + p_wei + wei_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0), + wei_block_desc, + p_wei_block, + f_copy); + + // copy output tensor to LDS + blockwise_4d_tensor_op( + out_desc, + p_out + out_desc.Get1dIndex(n_block_work_begin, + k_block_work_begin, + ho_block_work_begin, + 
wo_block_work_begin), + out_block_desc, + p_out_block, + f_copy); + + // blockwise convolution + blockwise_convolution( + in_block_desc, p_in_block, wei_block_desc, p_wei_block, out_block_desc, p_out_block); + + if(threadIdx.x == 0 ) + printf("====== 3:\n"); + + // accum output tensor from LDS to device mem + blockwise_4d_tensor_op( + out_block_desc, + p_out_block, + out_desc, + p_out + out_desc.Get1dIndex(n_block_work_begin, + k_block_work_begin, + ho_block_work_begin, + wo_block_work_begin), + f_copy); + } }