mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 12:30:16 +00:00
initial direct conv correct run
[ROCm/composable_kernel commit: 9657baec32]
This commit is contained in:
@@ -14,11 +14,29 @@ __device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc
|
||||
F f)
|
||||
{
|
||||
#if 1
|
||||
if(threadIdx.x < 100)
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("====== blockwise_4d_tensor_op: \t"
|
||||
"threadIdx.x %u, p_src[threadIdx.x] %f, p_dst[threadIdx.x] %f\n",
|
||||
threadIdx.x, p_src[threadIdx.x], p_dst[threadIdx.x]);
|
||||
printf("blockwise_4d_tensor_op: 0: \t"
|
||||
"threadIdx.x %u \t"
|
||||
"src_desc {%u %u %u %u}, {%u %u %u %u}\t"
|
||||
"dst_desc {%u %u %u %u}, {%u %u %u %u}\n",
|
||||
threadIdx.x,
|
||||
src_desc.GetLength(0),
|
||||
src_desc.GetLength(1),
|
||||
src_desc.GetLength(2),
|
||||
src_desc.GetLength(3),
|
||||
src_desc.GetStride(0),
|
||||
src_desc.GetStride(1),
|
||||
src_desc.GetStride(2),
|
||||
src_desc.GetStride(3),
|
||||
dst_desc.GetLength(0),
|
||||
dst_desc.GetLength(1),
|
||||
dst_desc.GetLength(2),
|
||||
dst_desc.GetLength(3),
|
||||
dst_desc.GetStride(0),
|
||||
dst_desc.GetStride(1),
|
||||
dst_desc.GetStride(2),
|
||||
dst_desc.GetStride(3));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -60,13 +78,21 @@ __device__ void blockwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_desc
|
||||
dst_desc.GetStride(0) * did0 + dst_desc.GetStride(1) * did1 +
|
||||
dst_desc.GetStride(2) * did2 + dst_desc.GetStride(3) * did3;
|
||||
|
||||
f(p_dst[dindex], p_src[sindex]);
|
||||
f(p_src[dindex], p_dst[sindex]);
|
||||
|
||||
#if 1
|
||||
printf("thread id %u, dindex %u, p_dst[dindex] %f, sindex %u, p_src[sindex] %f\n",
|
||||
threadIdx.x, dindex, p_dst[dindex], sindex, p_src[sindex]);
|
||||
// if(threadIdx.x == 0)
|
||||
{
|
||||
printf("blockwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -80,6 +106,33 @@ __device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_des
|
||||
TFloat* __restrict__ p_dst,
|
||||
F f)
|
||||
{
|
||||
#if 1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 0: \t"
|
||||
"threadIdx.x %u \t"
|
||||
"src_desc {%u %u %u %u}, {%u %u %u %u}\t"
|
||||
"dst_desc {%u %u %u %u}, {%u %u %u %u}\n",
|
||||
threadIdx.x,
|
||||
src_desc.GetLength(0),
|
||||
src_desc.GetLength(1),
|
||||
src_desc.GetLength(2),
|
||||
src_desc.GetLength(3),
|
||||
src_desc.GetStride(0),
|
||||
src_desc.GetStride(1),
|
||||
src_desc.GetStride(2),
|
||||
src_desc.GetStride(3),
|
||||
dst_desc.GetLength(0),
|
||||
dst_desc.GetLength(1),
|
||||
dst_desc.GetLength(2),
|
||||
dst_desc.GetLength(3),
|
||||
dst_desc.GetStride(0),
|
||||
dst_desc.GetStride(1),
|
||||
dst_desc.GetStride(2),
|
||||
dst_desc.GetStride(3));
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned did0 = 0; did0 < src_desc.GetLength(0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < src_desc.GetLength(1); ++did1)
|
||||
@@ -96,7 +149,21 @@ __device__ void threadwise_4d_tensor_op(const DeviceTensorDescriptor<4>& src_des
|
||||
dst_desc.GetStride(0) * did0 + dst_desc.GetStride(1) * did1 +
|
||||
dst_desc.GetStride(2) * did2 + dst_desc.GetStride(3) * did3;
|
||||
|
||||
f(p_dst[dindex], p_src[sindex]);
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
|
||||
#if 1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -111,6 +178,72 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
|
||||
const DeviceTensorDescriptor<4>& out_desc,
|
||||
TFloat* __restrict__ p_out)
|
||||
{
|
||||
#if 1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_direct_convolution: 0: \t"
|
||||
"threadIdx.x %u \t"
|
||||
"in_desc {%u %u %u %u}, {%u %u %u %u}\t"
|
||||
"wei_desc {%u %u %u %u}, {%u %u %u %u}\t"
|
||||
"out_desc {%u %u %u %u}, {%u %u %u %u}\n",
|
||||
threadIdx.x,
|
||||
in_desc.GetLength(0),
|
||||
in_desc.GetLength(1),
|
||||
in_desc.GetLength(2),
|
||||
in_desc.GetLength(3),
|
||||
in_desc.GetStride(0),
|
||||
in_desc.GetStride(1),
|
||||
in_desc.GetStride(2),
|
||||
in_desc.GetStride(3),
|
||||
wei_desc.GetLength(0),
|
||||
wei_desc.GetLength(1),
|
||||
wei_desc.GetLength(2),
|
||||
wei_desc.GetLength(3),
|
||||
wei_desc.GetStride(0),
|
||||
wei_desc.GetStride(1),
|
||||
wei_desc.GetStride(2),
|
||||
wei_desc.GetStride(3),
|
||||
out_desc.GetLength(0),
|
||||
out_desc.GetLength(1),
|
||||
out_desc.GetLength(2),
|
||||
out_desc.GetLength(3),
|
||||
out_desc.GetStride(0),
|
||||
out_desc.GetStride(1),
|
||||
out_desc.GetStride(2),
|
||||
out_desc.GetStride(3));
|
||||
}
|
||||
#elif 1
|
||||
{
|
||||
printf("threadwise_direct_convolution: 0: \t"
|
||||
"threadIdx.x %u \t"
|
||||
"p_in %f %f %f %f %f %f %f %f, \t"
|
||||
"p_wei %f %f %f %f %f %f %f %f %f, \t"
|
||||
"p_out %f %f %f %f, \n",
|
||||
threadIdx.x,
|
||||
p_in[0],
|
||||
p_in[1],
|
||||
p_in[2],
|
||||
p_in[3],
|
||||
p_in[4],
|
||||
p_in[5],
|
||||
p_in[6],
|
||||
p_in[7],
|
||||
p_wei[0],
|
||||
p_wei[1],
|
||||
p_wei[2],
|
||||
p_wei[3],
|
||||
p_wei[4],
|
||||
p_wei[5],
|
||||
p_wei[6],
|
||||
p_wei[7],
|
||||
p_wei[8],
|
||||
p_out[0],
|
||||
p_out[1],
|
||||
p_out[2],
|
||||
p_out[3]);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned n = 0; n < out_desc.GetLength(0); ++n)
|
||||
{
|
||||
for(unsigned k = 0; k < out_desc.GetLength(1); ++k)
|
||||
@@ -143,15 +276,20 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
|
||||
p_out[out_index] += p_wei[wei_index] * p_in[in_index];
|
||||
|
||||
#if 1
|
||||
if(threadIdx.x == 0 )
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("====== 5: \t"
|
||||
printf("threadwise_direct_convolution: 1: \t"
|
||||
"threadIdx.x %u\t"
|
||||
"out_index %u, p_out[out_index] %f, \t"
|
||||
"wei_index %u, p_wei[wei_index] %f, \t"
|
||||
"in_index %u, p_in[in_index] %f\n",
|
||||
out_index, p_out[out_index],
|
||||
wei_index, p_wei[wei_index],
|
||||
in_index, p_in[in_index]);
|
||||
threadIdx.x,
|
||||
out_index,
|
||||
p_out[out_index],
|
||||
wei_index,
|
||||
p_wei[wei_index],
|
||||
in_index,
|
||||
p_in[in_index]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -161,8 +299,6 @@ __device__ void threadwise_direct_convolution(const DeviceTensorDescriptor<4>& i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
||||
template <class TFloat,
|
||||
@@ -184,36 +320,87 @@ __device__ void blockwise_convolution(const DeviceTensorDescriptor<4>& in_desc,
|
||||
const DeviceTensorDescriptor<4>& out_desc,
|
||||
TFloat* __restrict__ p_out)
|
||||
{
|
||||
#if 1
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("blockwise_convolution: 0: \t"
|
||||
"threadIdx.x %u \t"
|
||||
"in_desc {%u %u %u %u}, {%u %u %u %u}\t"
|
||||
"wei_desc {%u %u %u %u}, {%u %u %u %u}\t"
|
||||
"out_desc {%u %u %u %u}, {%u %u %u %u}\n",
|
||||
threadIdx.x,
|
||||
in_desc.GetLength(0),
|
||||
in_desc.GetLength(1),
|
||||
in_desc.GetLength(2),
|
||||
in_desc.GetLength(3),
|
||||
in_desc.GetStride(0),
|
||||
in_desc.GetStride(1),
|
||||
in_desc.GetStride(2),
|
||||
in_desc.GetStride(3),
|
||||
wei_desc.GetLength(0),
|
||||
wei_desc.GetLength(1),
|
||||
wei_desc.GetLength(2),
|
||||
wei_desc.GetLength(3),
|
||||
wei_desc.GetStride(0),
|
||||
wei_desc.GetStride(1),
|
||||
wei_desc.GetStride(2),
|
||||
wei_desc.GetStride(3),
|
||||
out_desc.GetLength(0),
|
||||
out_desc.GetLength(1),
|
||||
out_desc.GetLength(2),
|
||||
out_desc.GetLength(3),
|
||||
out_desc.GetStride(0),
|
||||
out_desc.GetStride(1),
|
||||
out_desc.GetStride(2),
|
||||
out_desc.GetStride(3));
|
||||
}
|
||||
#endif
|
||||
|
||||
// for now, one thread do 1 N and 1 K
|
||||
DeviceTensorDescriptor<4> wei_thread_desc;
|
||||
wei_thread_desc.mpLengths[0] = 1;
|
||||
wei_thread_desc.mpLengths[1] = CPerBlockLoop;
|
||||
wei_thread_desc.mpLengths[2] = S;
|
||||
wei_thread_desc.mpLengths[3] = R;
|
||||
wei_thread_desc.mpStrides[3] = 1;
|
||||
wei_thread_desc.mpStrides[2] = wei_thread_desc.GetLength(3) * wei_thread_desc.GetStride(3);
|
||||
wei_thread_desc.mpStrides[1] = wei_thread_desc.GetLength(2) * wei_thread_desc.GetStride(2);
|
||||
wei_thread_desc.mpStrides[0] = wei_thread_desc.GetLength(1) * wei_thread_desc.GetStride(1);
|
||||
DeviceTensorDescriptor<4> in_thread_src_desc = in_desc;
|
||||
in_thread_src_desc.mpLengths[0] = 1;
|
||||
in_thread_src_desc.mpLengths[1] = CPerBlockLoop;
|
||||
in_thread_src_desc.mpLengths[2] = OutTileSizeH + S - 1;
|
||||
in_thread_src_desc.mpLengths[3] = OutTileSizeW + R - 1;
|
||||
|
||||
DeviceTensorDescriptor<4> out_thread_desc;
|
||||
out_thread_desc.mpLengths[0] = 1;
|
||||
out_thread_desc.mpLengths[1] = 1;
|
||||
out_thread_desc.mpLengths[2] = OutTileSizeH;
|
||||
out_thread_desc.mpLengths[3] = OutTileSizeW;
|
||||
out_thread_desc.mpStrides[3] = 1;
|
||||
out_thread_desc.mpStrides[2] = out_thread_desc.GetLength(3) * out_thread_desc.GetStride(3);
|
||||
out_thread_desc.mpStrides[1] = out_thread_desc.GetLength(2) * out_thread_desc.GetStride(2);
|
||||
out_thread_desc.mpStrides[0] = out_thread_desc.GetLength(1) * out_thread_desc.GetStride(1);
|
||||
DeviceTensorDescriptor<4> wei_thread_src_desc = wei_desc;
|
||||
wei_thread_src_desc.mpLengths[0] = 1;
|
||||
wei_thread_src_desc.mpLengths[1] = CPerBlockLoop;
|
||||
wei_thread_src_desc.mpLengths[2] = S;
|
||||
wei_thread_src_desc.mpLengths[3] = R;
|
||||
|
||||
DeviceTensorDescriptor<4> in_thread_desc;
|
||||
in_thread_desc.mpLengths[0] = 1;
|
||||
in_thread_desc.mpLengths[1] = CPerBlockLoop;
|
||||
in_thread_desc.mpLengths[2] = OutTileSizeH + S - 1;
|
||||
in_thread_desc.mpLengths[3] = OutTileSizeW + R - 1;
|
||||
in_thread_desc.mpStrides[3] = 1;
|
||||
in_thread_desc.mpStrides[2] = in_thread_desc.GetLength(3) * in_thread_desc.GetStride(3);
|
||||
in_thread_desc.mpStrides[1] = in_thread_desc.GetLength(2) * in_thread_desc.GetStride(2);
|
||||
in_thread_desc.mpStrides[0] = in_thread_desc.GetLength(1) * in_thread_desc.GetStride(1);
|
||||
DeviceTensorDescriptor<4> out_thread_src_desc = out_desc;
|
||||
out_thread_src_desc.mpLengths[0] = 1;
|
||||
out_thread_src_desc.mpLengths[1] = 1;
|
||||
out_thread_src_desc.mpLengths[2] = OutTileSizeH;
|
||||
out_thread_src_desc.mpLengths[3] = OutTileSizeW;
|
||||
|
||||
DeviceTensorDescriptor<4> in_thread_dst_desc = in_thread_src_desc;
|
||||
in_thread_dst_desc.mpStrides[3] = 1;
|
||||
in_thread_dst_desc.mpStrides[2] =
|
||||
in_thread_dst_desc.GetLength(3) * in_thread_dst_desc.GetStride(3);
|
||||
in_thread_dst_desc.mpStrides[1] =
|
||||
in_thread_dst_desc.GetLength(2) * in_thread_dst_desc.GetStride(2);
|
||||
in_thread_dst_desc.mpStrides[0] =
|
||||
in_thread_dst_desc.GetLength(1) * in_thread_dst_desc.GetStride(1);
|
||||
|
||||
DeviceTensorDescriptor<4> wei_thread_dst_desc = wei_thread_src_desc;
|
||||
wei_thread_dst_desc.mpStrides[3] = 1;
|
||||
wei_thread_dst_desc.mpStrides[2] =
|
||||
wei_thread_dst_desc.GetLength(3) * wei_thread_dst_desc.GetStride(3);
|
||||
wei_thread_dst_desc.mpStrides[1] =
|
||||
wei_thread_dst_desc.GetLength(2) * wei_thread_dst_desc.GetStride(2);
|
||||
wei_thread_dst_desc.mpStrides[0] =
|
||||
wei_thread_dst_desc.GetLength(1) * wei_thread_dst_desc.GetStride(1);
|
||||
|
||||
DeviceTensorDescriptor<4> out_thread_dst_desc = out_thread_src_desc;
|
||||
out_thread_dst_desc.mpStrides[3] = 1;
|
||||
out_thread_dst_desc.mpStrides[2] =
|
||||
out_thread_dst_desc.GetLength(3) * out_thread_dst_desc.GetStride(3);
|
||||
out_thread_dst_desc.mpStrides[1] =
|
||||
out_thread_dst_desc.GetLength(2) * out_thread_dst_desc.GetStride(2);
|
||||
out_thread_dst_desc.mpStrides[0] =
|
||||
out_thread_dst_desc.GetLength(1) * out_thread_dst_desc.GetStride(1);
|
||||
|
||||
const unsigned thread_sz = blockDim.x * blockDim.y * blockDim.z;
|
||||
|
||||
@@ -248,45 +435,45 @@ __device__ void blockwise_convolution(const DeviceTensorDescriptor<4>& in_desc,
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_op<TFloat, decltype(f_copy)>(
|
||||
in_desc,
|
||||
in_thread_src_desc,
|
||||
p_in + in_desc.Get1dIndex(
|
||||
n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin),
|
||||
in_thread_desc,
|
||||
in_thread_dst_desc,
|
||||
p_in_thread,
|
||||
f_copy);
|
||||
|
||||
// copy weight tensor into register
|
||||
threadwise_4d_tensor_op<TFloat, decltype(f_copy)>(
|
||||
wei_desc,
|
||||
p_wei + wei_thread_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
|
||||
wei_thread_desc,
|
||||
wei_thread_src_desc,
|
||||
p_wei + wei_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
|
||||
wei_thread_dst_desc,
|
||||
p_wei_thread,
|
||||
f_copy);
|
||||
|
||||
// copy output tensor into register
|
||||
threadwise_4d_tensor_op<TFloat, decltype(f_copy)>(
|
||||
out_desc,
|
||||
out_thread_src_desc,
|
||||
p_out + out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
out_thread_desc,
|
||||
out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
f_copy);
|
||||
|
||||
// threadwise convolution
|
||||
threadwise_direct_convolution(in_thread_desc,
|
||||
threadwise_direct_convolution(in_thread_dst_desc,
|
||||
p_in_thread,
|
||||
wei_thread_desc,
|
||||
wei_thread_dst_desc,
|
||||
p_wei_thread,
|
||||
out_thread_desc,
|
||||
out_thread_dst_desc,
|
||||
p_out_thread);
|
||||
|
||||
// accumulate output tensor into device mem
|
||||
threadwise_4d_tensor_op<TFloat, decltype(f_copy)>(
|
||||
out_thread_desc,
|
||||
out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
out_desc,
|
||||
out_thread_src_desc,
|
||||
p_out + out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
@@ -315,11 +502,38 @@ __global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc,
|
||||
TFloat* __restrict__ p_out)
|
||||
{
|
||||
#if 1
|
||||
if(threadIdx.x < 100)
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("====== 0: \t"
|
||||
"threadIdx.x %u, p_in[threadIdx.x] %f, p_wei[threadIdx.x] %f, p_out[threadIdx.x] %f\n",
|
||||
threadIdx.x, p_in[threadIdx.x], p_wei[threadIdx.x], p_out[threadIdx.x]);
|
||||
printf("gridwise_convolution: 0: \t"
|
||||
"threadIdx.x %u \t"
|
||||
"in_desc {%u %u %u %u}, {%u %u %u %u}\t"
|
||||
"wei_desc {%u %u %u %u}, {%u %u %u %u}\t"
|
||||
"out_desc {%u %u %u %u}, {%u %u %u %u}\n",
|
||||
threadIdx.x,
|
||||
in_desc.GetLength(0),
|
||||
in_desc.GetLength(1),
|
||||
in_desc.GetLength(2),
|
||||
in_desc.GetLength(3),
|
||||
in_desc.GetStride(0),
|
||||
in_desc.GetStride(1),
|
||||
in_desc.GetStride(2),
|
||||
in_desc.GetStride(3),
|
||||
wei_desc.GetLength(0),
|
||||
wei_desc.GetLength(1),
|
||||
wei_desc.GetLength(2),
|
||||
wei_desc.GetLength(3),
|
||||
wei_desc.GetStride(0),
|
||||
wei_desc.GetStride(1),
|
||||
wei_desc.GetStride(2),
|
||||
wei_desc.GetStride(3),
|
||||
out_desc.GetLength(0),
|
||||
out_desc.GetLength(1),
|
||||
out_desc.GetLength(2),
|
||||
out_desc.GetLength(3),
|
||||
out_desc.GetStride(0),
|
||||
out_desc.GetStride(1),
|
||||
out_desc.GetStride(2),
|
||||
out_desc.GetStride(3));
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -363,9 +577,9 @@ __global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc,
|
||||
in_block_desc.mpStrides[1] = in_block_desc.GetLength(2) * in_block_desc.GetStride(2);
|
||||
in_block_desc.mpStrides[0] = in_block_desc.GetLength(1) * in_block_desc.GetStride(1);
|
||||
|
||||
__shared__ TFloat p_in_block[NPerBlock * CPerBlockLoop * S * R];
|
||||
__shared__ TFloat p_wei_block[KPerBlock * CPerBlockLoop * (YPerBlock * OutTileSizeH + S - 1) *
|
||||
(XPerBlock * OutTileSizeW + R - 1)];
|
||||
__shared__ TFloat p_in_block[NPerBlock * CPerBlockLoop * (YPerBlock * OutTileSizeH + S - 1) *
|
||||
(XPerBlock * OutTileSizeW + R - 1)];
|
||||
__shared__ TFloat p_wei_block[KPerBlock * CPerBlockLoop * S * R];
|
||||
__shared__ TFloat p_out_block[NPerBlock * KPerBlock * (YPerBlock * OutTileSizeH) *
|
||||
(XPerBlock * OutTileSizeW)];
|
||||
|
||||
@@ -388,9 +602,6 @@ __global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc,
|
||||
unsigned hi_block_work_begin = ho_block_work_begin; // minus padding
|
||||
unsigned wi_block_work_begin = wo_block_work_begin; // minus padding
|
||||
|
||||
if(threadIdx.x == 0)
|
||||
printf("====== 1:\n");
|
||||
|
||||
for(unsigned c_block_work_begin = 0; c_block_work_begin < in_desc.GetLength(1);
|
||||
c_block_work_begin += CPerBlockLoop)
|
||||
{
|
||||
@@ -426,6 +637,8 @@ __global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc,
|
||||
p_out_block,
|
||||
f_copy);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// blockwise convolution
|
||||
blockwise_convolution<TFloat,
|
||||
S,
|
||||
@@ -441,8 +654,7 @@ __global__ void gridwise_convolution(const DeviceTensorDescriptor<4> in_desc,
|
||||
CPerBlockLoop>(
|
||||
in_block_desc, p_in_block, wei_block_desc, p_wei_block, out_block_desc, p_out_block);
|
||||
|
||||
if(threadIdx.x == 0 )
|
||||
printf("====== 3:\n");
|
||||
__syncthreads();
|
||||
|
||||
// accum output tensor from LDS to device mem
|
||||
blockwise_4d_tensor_op<TFloat, 1, 1, 1, 64, decltype(f_copy)>(
|
||||
|
||||
Reference in New Issue
Block a user