[ROCm/composable_kernel commit: 73480fee36]
This commit is contained in:
Chao Liu
2018-11-15 23:53:23 -06:00
parent d33b269f6d
commit 3e4752bf7e
4 changed files with 115 additions and 209 deletions

View File

@@ -5,58 +5,35 @@
#if THREADWISE_TENSOR_OP_METHOD == 0
template <class TFloat, class SrcDesc, class DstDesc, class F>
__device__ void threadwise_4d_tensor_op_in(
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
__device__ void threadwise_4d_tensor_op_unary(DstDesc, TFloat* __restrict__ p_dst, F f)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
#if 0
if(threadIdx.x == 0)
{
print_ConstantTensorDescriptor(src_desc);
print_ConstantTensorDescriptor(dst_desc);
}
#endif
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
for(unsigned did0 = 0; did0 < dst_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < src_desc.GetLength(I1); ++did1)
for(unsigned did1 = 0; did1 < dst_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < src_desc.GetLength(I2); ++did2)
for(unsigned did2 = 0; did2 < dst_desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
for(unsigned did3 = 0; did3 < dst_desc.GetLength(I3); ++did3)
{
const unsigned sindex =
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
const unsigned dindex =
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
f(p_src[sindex], p_dst[dindex]);
#if 0
if(threadIdx.x == 0)
{
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
"sindex %u, p_src[sindex] %f, \t"
"dindex %u, p_dst[dindex] %f\n",
threadIdx.x,
sindex,
p_src[sindex],
dindex,
p_dst[dindex]);
}
#endif
f(p_dst[dindex]);
}
}
}
@@ -64,7 +41,7 @@ __device__ void threadwise_4d_tensor_op_in(
}
template <class TFloat, class SrcDesc, class DstDesc, class F>
__device__ void threadwise_4d_tensor_op_wei(
__device__ void threadwise_4d_tensor_op_binary(
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
{
constexpr auto I0 = Index<0>{};
@@ -102,79 +79,6 @@ __device__ void threadwise_4d_tensor_op_wei(
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
f(p_src[sindex], p_dst[dindex]);
#if 0
if(threadIdx.x == 0)
{
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
"sindex %u, p_src[sindex] %f, \t"
"dindex %u, p_dst[dindex] %f\n",
threadIdx.x,
sindex,
p_src[sindex],
dindex,
p_dst[dindex]);
}
#endif
}
}
}
}
}
template <class TFloat, class SrcDesc, class DstDesc, class F>
__device__ void threadwise_4d_tensor_op_out(
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
{
constexpr auto I0 = Index<0>{};
constexpr auto I1 = Index<1>{};
constexpr auto I2 = Index<2>{};
constexpr auto I3 = Index<3>{};
constexpr auto src_desc = SrcDesc{};
constexpr auto dst_desc = DstDesc{};
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
#if 0
if(threadIdx.x == 0)
{
print_ConstantTensorDescriptor(src_desc);
print_ConstantTensorDescriptor(dst_desc);
}
#endif
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
{
for(unsigned did1 = 0; did1 < src_desc.GetLength(I1); ++did1)
{
for(unsigned did2 = 0; did2 < src_desc.GetLength(I2); ++did2)
{
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
{
const unsigned sindex =
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
const unsigned dindex =
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
f(p_src[sindex], p_dst[dindex]);
#if 0
if(threadIdx.x == 0)
{
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
"sindex %u, p_src[sindex] %f, \t"
"dindex %u, p_dst[dindex] %f\n",
threadIdx.x,
sindex,
p_src[sindex],
dindex,
p_dst[dindex]);
}
#endif
}
}
}