mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
add 2nd version of blockwise_tensor_op
This commit is contained in:
155
src/include/blockwise_tensor_op.cuh
Normal file
155
src/include/blockwise_tensor_op.cuh
Normal file
@@ -0,0 +1,155 @@
|
||||
#pragma once
|
||||
#include "constant_tensor_descriptor.cuh"
|
||||
|
||||
#if 0
|
||||
template <class TFloat,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
unsigned NWorkLen0,
|
||||
unsigned NWorkLen1,
|
||||
unsigned NWorkLen2,
|
||||
unsigned NWorkLen3,
|
||||
class F,
|
||||
unsigned BlockSize>
|
||||
__device__ void blockwise_4d_tensor_op(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "blockwise_4d_tensor_op: src_desc: ");
|
||||
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op: dst_desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr unsigned NWorkStride3 = 1;
|
||||
constexpr unsigned NWorkStride2 = NWorkLen3 * NWorkStride3;
|
||||
constexpr unsigned NWorkStride1 = NWorkLen2 * NWorkStride2;
|
||||
constexpr unsigned NWorkStride0 = NWorkLen1 * NWorkStride1;
|
||||
|
||||
unsigned itmp =
|
||||
threadIdx.x;
|
||||
|
||||
const unsigned did0_begin = itmp / NWorkStride0;
|
||||
|
||||
itmp -= did0_begin * NWorkStride0;
|
||||
|
||||
const unsigned did1_begin = itmp / NWorkStride1;
|
||||
|
||||
itmp -= did1_begin * NWorkStride1;
|
||||
|
||||
const unsigned did2_begin = itmp / NWorkStride2;
|
||||
|
||||
itmp -= did2_begin * NWorkStride2;
|
||||
|
||||
const unsigned did3_begin = itmp / NWorkStride3;
|
||||
|
||||
for(unsigned did0 = did0_begin; did0 < src_desc.GetLength(I0); did0 += NWorkLen0)
|
||||
{
|
||||
for(unsigned did1 = did1_begin; did1 < src_desc.GetLength(I1); did1 += NWorkLen1)
|
||||
{
|
||||
for(unsigned did2 = did2_begin; did2 < src_desc.GetLength(I2); did2 += NWorkLen2)
|
||||
{
|
||||
for(unsigned did3 = did3_begin; did3 < src_desc.GetLength(I3); did3 += NWorkLen3)
|
||||
{
|
||||
const unsigned sindex =
|
||||
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
|
||||
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
|
||||
|
||||
const unsigned dindex =
|
||||
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
|
||||
f(p_src[dindex], p_dst[sindex]);
|
||||
|
||||
#if 0
|
||||
// if(threadIdx.x == 0)
|
||||
{
|
||||
printf("blockwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#elif 1
|
||||
|
||||
template <class TFloat,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
unsigned NWorkLen0,
|
||||
unsigned NWorkLen1,
|
||||
unsigned NWorkLen2,
|
||||
unsigned NWorkLen3,
|
||||
class F,
|
||||
unsigned BlockSize>
|
||||
__device__ void blockwise_4d_tensor_op(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "blockwise_4d_tensor_op: src_desc: ");
|
||||
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op: dst_desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
unsigned lid = threadIdx.x;
|
||||
|
||||
for(unsigned i = lid; i < src_desc.GetElementSize(); i += BlockSize)
|
||||
{
|
||||
unsigned is = i;
|
||||
|
||||
const unsigned did0 = is / src_desc.GetStride(I0);
|
||||
|
||||
is -= did0 * src_desc.GetStride(I0);
|
||||
|
||||
const unsigned did1 = is / src_desc.GetStride(I1);
|
||||
|
||||
is -= did1 * src_desc.GetStride(I1);
|
||||
|
||||
const unsigned did2 = is / src_desc.GetStride(I2);
|
||||
|
||||
is -= did2 * src_desc.GetStride(I2);
|
||||
|
||||
const unsigned did3 = is / src_desc.GetStride(I3);
|
||||
|
||||
const unsigned sindex = src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
|
||||
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
|
||||
|
||||
const unsigned dindex = dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
13
src/include/common.cuh
Normal file
13
src/include/common.cuh
Normal file
@@ -0,0 +1,13 @@
|
||||
#pragma once
|
||||
|
||||
template <class T1, class T2>
|
||||
struct is_same
|
||||
{
|
||||
static const bool value = false;
|
||||
};
|
||||
|
||||
template <class T>
|
||||
struct is_same<T, T>
|
||||
{
|
||||
static const bool value = true;
|
||||
};
|
||||
@@ -1,5 +1,5 @@
|
||||
#pragma once
|
||||
#include "helper_cuda.h"
|
||||
#include "common.cuh"
|
||||
|
||||
template <class T, T N>
|
||||
struct Constant
|
||||
|
||||
@@ -1,148 +1,7 @@
|
||||
#pragma once
|
||||
#include "constant_tensor_descriptor.cuh"
|
||||
|
||||
template <class TFloat,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
unsigned NWorkLen0,
|
||||
unsigned NWorkLen1,
|
||||
unsigned NWorkLen2,
|
||||
unsigned NWorkLen3,
|
||||
class F>
|
||||
__device__ void blockwise_4d_tensor_op(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "blockwise_4d_tensor_op: src_desc: ");
|
||||
print_ConstantTensorDescriptor(dst_desc, "blockwise_4d_tensor_op: dst_desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
constexpr unsigned NWorkStride3 = 1;
|
||||
constexpr unsigned NWorkStride2 = NWorkLen3 * NWorkStride3;
|
||||
constexpr unsigned NWorkStride1 = NWorkLen2 * NWorkStride2;
|
||||
constexpr unsigned NWorkStride0 = NWorkLen1 * NWorkStride1;
|
||||
|
||||
unsigned itmp =
|
||||
threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.y * blockDim.x);
|
||||
|
||||
const unsigned did0_begin = itmp / NWorkStride0;
|
||||
|
||||
itmp -= did0_begin * NWorkStride0;
|
||||
|
||||
const unsigned did1_begin = itmp / NWorkStride1;
|
||||
|
||||
itmp -= did1_begin * NWorkStride1;
|
||||
|
||||
const unsigned did2_begin = itmp / NWorkStride2;
|
||||
|
||||
itmp -= did2_begin * NWorkStride2;
|
||||
|
||||
const unsigned did3_begin = itmp / NWorkStride3;
|
||||
|
||||
for(unsigned did0 = did0_begin; did0 < src_desc.GetLength(I0); did0 += NWorkLen0)
|
||||
{
|
||||
for(unsigned did1 = did1_begin; did1 < src_desc.GetLength(I1); did1 += NWorkLen1)
|
||||
{
|
||||
for(unsigned did2 = did2_begin; did2 < src_desc.GetLength(I2); did2 += NWorkLen2)
|
||||
{
|
||||
for(unsigned did3 = did3_begin; did3 < src_desc.GetLength(I3); did3 += NWorkLen3)
|
||||
{
|
||||
const unsigned sindex =
|
||||
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
|
||||
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
|
||||
|
||||
const unsigned dindex =
|
||||
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
|
||||
f(p_src[dindex], p_dst[sindex]);
|
||||
|
||||
#if 0
|
||||
// if(threadIdx.x == 0)
|
||||
{
|
||||
printf("blockwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class F>
|
||||
__device__ void threadwise_4d_tensor_op(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc);
|
||||
print_ConstantTensorDescriptor(dst_desc);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < src_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < src_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
const unsigned sindex =
|
||||
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
|
||||
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
|
||||
|
||||
const unsigned dindex =
|
||||
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#include "blockwise_tensor_op.cuh"
|
||||
#include "threadwise_tensor_op.cuh"
|
||||
|
||||
template <class TFloat, class InDesc, class WeiDesc, class OutDesc>
|
||||
__device__ void threadwise_direct_convolution(InDesc,
|
||||
@@ -232,7 +91,8 @@ template <class TFloat,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
unsigned OutTileSizeH,
|
||||
unsigned OutTileSizeW>
|
||||
unsigned OutTileSizeW,
|
||||
unsigned BlockSize>
|
||||
__device__ void blockwise_convolution(InDesc,
|
||||
TFloat* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
@@ -290,14 +150,11 @@ __device__ void blockwise_convolution(InDesc,
|
||||
constexpr auto out_thread_dst_desc =
|
||||
make_ConstantTensorDescriptor(out_thread_src_desc.GetLengths());
|
||||
|
||||
const unsigned thread_sz = blockDim.x * blockDim.y * blockDim.z;
|
||||
|
||||
const unsigned thread_id =
|
||||
threadIdx.x + threadIdx.y * blockDim.x + threadIdx.z * (blockDim.y * blockDim.x);
|
||||
const unsigned thread_id = threadIdx.x;
|
||||
|
||||
for(unsigned thread_work_id = thread_id;
|
||||
thread_work_id < NPerBlock * KPerBlock * YPerBlock * XPerBlock;
|
||||
thread_work_id += thread_sz)
|
||||
thread_work_id += BlockSize)
|
||||
{
|
||||
unsigned itmp = thread_work_id;
|
||||
unsigned n_thread_work_id = itmp / (KPerBlock * YPerBlock * XPerBlock);
|
||||
@@ -397,7 +254,9 @@ template <class TFloat,
|
||||
unsigned NBlockCopyLen0,
|
||||
unsigned NBlockCopyLen1,
|
||||
unsigned NBlockCopyLen2,
|
||||
unsigned NBlockCopyLen3>
|
||||
unsigned NBlockCopyLen3,
|
||||
unsigned BlockSize,
|
||||
unsigned GridSize>
|
||||
__global__ void gridwise_convolution(InDesc,
|
||||
TFloat* const __restrict__ p_in,
|
||||
WeiDesc,
|
||||
@@ -452,8 +311,7 @@ __global__ void gridwise_convolution(InDesc,
|
||||
__shared__ TFloat p_wei_block[wei_block_size];
|
||||
__shared__ TFloat p_out_block[out_block_size];
|
||||
|
||||
const unsigned block_id =
|
||||
blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * (gridDim.y * gridDim.x);
|
||||
const unsigned block_id = blockIdx.x;
|
||||
|
||||
unsigned itmp = block_id;
|
||||
unsigned n_block_work_id = itmp / (KBlockWork * YBlockWork * XBlockWork);
|
||||
@@ -515,17 +373,16 @@ __global__ void gridwise_convolution(InDesc,
|
||||
NBlockCopyLen1,
|
||||
NBlockCopyLen2,
|
||||
NBlockCopyLen3,
|
||||
decltype(f_copy)>(
|
||||
in_block_glb_desc,
|
||||
p_in + in_block_glb_desc.Get1dIndex(n_block_work_begin,
|
||||
c_block_work_begin,
|
||||
hi_block_work_begin,
|
||||
wi_block_work_begin),
|
||||
in_block_lds_desc,
|
||||
p_in_block,
|
||||
f_copy);
|
||||
decltype(f_copy),
|
||||
BlockSize>(in_block_glb_desc,
|
||||
p_in + in_block_glb_desc.Get1dIndex(n_block_work_begin,
|
||||
c_block_work_begin,
|
||||
hi_block_work_begin,
|
||||
wi_block_work_begin),
|
||||
in_block_lds_desc,
|
||||
p_in_block,
|
||||
f_copy);
|
||||
|
||||
#if 1
|
||||
// copy weight tensor to LDS
|
||||
blockwise_4d_tensor_op<TFloat,
|
||||
decltype(wei_block_glb_desc),
|
||||
@@ -534,7 +391,8 @@ __global__ void gridwise_convolution(InDesc,
|
||||
NBlockCopyLen1,
|
||||
NBlockCopyLen2,
|
||||
NBlockCopyLen3,
|
||||
decltype(f_copy)>(
|
||||
decltype(f_copy),
|
||||
BlockSize>(
|
||||
wei_block_glb_desc,
|
||||
p_wei + wei_block_glb_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
|
||||
wei_block_lds_desc,
|
||||
@@ -549,17 +407,18 @@ __global__ void gridwise_convolution(InDesc,
|
||||
NBlockCopyLen1,
|
||||
NBlockCopyLen2,
|
||||
NBlockCopyLen3,
|
||||
decltype(f_copy)>(
|
||||
out_block_glb_desc,
|
||||
p_out + out_block_glb_desc.Get1dIndex(n_block_work_begin,
|
||||
k_block_work_begin,
|
||||
ho_block_work_begin,
|
||||
wo_block_work_begin),
|
||||
out_block_lds_desc,
|
||||
p_out_block,
|
||||
f_copy);
|
||||
decltype(f_copy),
|
||||
BlockSize>(out_block_glb_desc,
|
||||
p_out +
|
||||
out_block_glb_desc.Get1dIndex(n_block_work_begin,
|
||||
k_block_work_begin,
|
||||
ho_block_work_begin,
|
||||
wo_block_work_begin),
|
||||
out_block_lds_desc,
|
||||
p_out_block,
|
||||
f_copy);
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
__syncthreads();
|
||||
#endif
|
||||
|
||||
@@ -569,14 +428,15 @@ __global__ void gridwise_convolution(InDesc,
|
||||
decltype(wei_block_lds_desc),
|
||||
decltype(out_block_lds_desc),
|
||||
OutTileSizeH,
|
||||
OutTileSizeW>(in_block_lds_desc,
|
||||
p_in_block,
|
||||
wei_block_lds_desc,
|
||||
p_wei_block,
|
||||
out_block_lds_desc,
|
||||
p_out_block);
|
||||
OutTileSizeW,
|
||||
BlockSize>(in_block_lds_desc,
|
||||
p_in_block,
|
||||
wei_block_lds_desc,
|
||||
p_wei_block,
|
||||
out_block_lds_desc,
|
||||
p_out_block);
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
__syncthreads();
|
||||
#endif
|
||||
|
||||
@@ -588,15 +448,15 @@ __global__ void gridwise_convolution(InDesc,
|
||||
NBlockCopyLen1,
|
||||
NBlockCopyLen2,
|
||||
NBlockCopyLen3,
|
||||
decltype(f_copy)>(
|
||||
out_block_lds_desc,
|
||||
p_out_block,
|
||||
out_block_glb_desc,
|
||||
p_out + out_block_glb_desc.Get1dIndex(n_block_work_begin,
|
||||
k_block_work_begin,
|
||||
ho_block_work_begin,
|
||||
wo_block_work_begin),
|
||||
f_copy);
|
||||
#endif
|
||||
decltype(f_copy),
|
||||
BlockSize>(out_block_lds_desc,
|
||||
p_out_block,
|
||||
out_block_glb_desc,
|
||||
p_out +
|
||||
out_block_glb_desc.Get1dIndex(n_block_work_begin,
|
||||
k_block_work_begin,
|
||||
ho_block_work_begin,
|
||||
wo_block_work_begin),
|
||||
f_copy);
|
||||
}
|
||||
}
|
||||
|
||||
61
src/include/threadwise_tensor_op.cuh
Normal file
61
src/include/threadwise_tensor_op.cuh
Normal file
@@ -0,0 +1,61 @@
|
||||
#pragma once
|
||||
#include "constant_tensor_descriptor.cuh"
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class F>
|
||||
__device__ void threadwise_4d_tensor_op(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc);
|
||||
print_ConstantTensorDescriptor(dst_desc);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < src_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < src_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
const unsigned sindex =
|
||||
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
|
||||
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
|
||||
|
||||
const unsigned dindex =
|
||||
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user