This commit is contained in:
Chao Liu
2019-03-06 12:34:31 -06:00
parent 04c5527d07
commit 8edbc659b8
2 changed files with 11 additions and 45 deletions

View File

@@ -10,8 +10,7 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
SrcOpLengths,
Number<DataPerRead>)
{
-using Float2 = float2;
-using Float4 = float4;
+using vector_t = typename vector_type<Float, DataPerRead>::type;
static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
SrcOpLengths::nDim == 6,
@@ -62,24 +61,8 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
const unsigned dst_index = dst_desc.Get1dIndex(
did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
-if(DataPerRead == 1)
-{
-p_dst[dst_index] = p_src[src_index];
-}
-else if(DataPerRead == 2)
-{
-*(reinterpret_cast<Float2*>(p_dst + dst_index)) =
-*(reinterpret_cast<const Float2*>(p_src + src_index));
-}
-else if(DataPerRead == 4)
-{
-*(reinterpret_cast<Float4*>(p_dst + dst_index)) =
-*(reinterpret_cast<const Float4*>(p_src + src_index));
-}
-else
-{
-assert(false);
-}
+*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
+*(reinterpret_cast<const vector_t*>(p_src + src_index));
}
}
}
@@ -97,8 +80,7 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
SrcOpLengths,
Number<DataPerRead>)
{
-using Float2 = float2;
-using Float4 = float4;
+using vector_t = typename vector_type<Float, DataPerRead>::type;
static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
SrcOpLengths::nDim == 8,
@@ -169,24 +151,8 @@ __device__ void threadwise_8d_tensor_copy(SrcDesc,
did6,
iloop_d7 * DataPerRead);
-if(DataPerRead == 1)
-{
-p_dst[dst_index] = p_src[src_index];
-}
-else if(DataPerRead == 2)
-{
-*(reinterpret_cast<Float2*>(p_dst + dst_index)) =
-*(reinterpret_cast<const Float2*>(p_src + src_index));
-}
-else if(DataPerRead == 4)
-{
-*(reinterpret_cast<Float4*>(p_dst + dst_index)) =
-*(reinterpret_cast<const Float4*>(p_src + src_index));
-}
-else
-{
-assert(false);
-}
+*(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
+*(reinterpret_cast<const vector_t*>(p_src + src_index));
}
}
}