mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
added threadwise tensor reorder operation
This commit is contained in:
@@ -76,6 +76,8 @@ blockwise_4d_tensor_pointwise_operation_unary(DstDesc, TFloat* __restrict__ p_ds
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <class TFloat,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
|
||||
@@ -182,7 +182,7 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
for(unsigned c_thread_data = 0; c_thread_data < CPerBlock; c_thread_data += CPerThread)
|
||||
{
|
||||
// threadwise convolution
|
||||
#if 1
|
||||
#if 0
|
||||
threadwise_direct_convolution_2(
|
||||
in_thread_block_desc,
|
||||
p_in_block + in_block_desc.Get1dIndex(n_thread_data_begin,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "constant_tensor_descriptor.cuh"
|
||||
|
||||
template <class TFloat, class Desc, class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_op_unary(Desc, TFloat* __restrict__ p, F f)
|
||||
__device__ void threadwise_4d_tensor_pointwise_operation_unary(Desc, TFloat* __restrict__ p, F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -35,41 +35,48 @@ __device__ void threadwise_4d_tensor_pointwise_op_unary(Desc, TFloat* __restrict
|
||||
}
|
||||
}
|
||||
|
||||
template <class TFloat, class DescA, class DescB, class DescRef, class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_op_binary(
|
||||
DescA, TFloat* const __restrict__ p_a, DescB, TFloat* __restrict__ p_b, DescRef, F f)
|
||||
// TODO: in order to optimize mem access for different mem type,
|
||||
// need to write specialized version
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class RefDesc, class Reorder, class F>
|
||||
__device__ void
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder(SrcDesc,
|
||||
TFloat* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
TFloat* __restrict__ p_dst,
|
||||
RefDesc,
|
||||
Reorder,
|
||||
F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto desc_a = DescA{};
|
||||
constexpr auto desc_b = DescB{};
|
||||
constexpr auto desc_ref = DescRef{};
|
||||
constexpr unsigned IT0 = Reorder{}.Get(I0);
|
||||
constexpr unsigned IT1 = Reorder{}.Get(I1);
|
||||
constexpr unsigned IT2 = Reorder{}.Get(I2);
|
||||
constexpr unsigned IT3 = Reorder{}.Get(I3);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(desc_a, "threadwise_4d_tensor_op_binary: desc_a: ");
|
||||
print_ConstantTensorDescriptor(desc_b, "threadwise_4d_tensor_op_binary: desc_b: ");
|
||||
print_ConstantTensorDescriptor(desc_ref, "threadwise_4d_tensor_op_binary: desc_ref: ");
|
||||
}
|
||||
#endif
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto ref_desc = RefDesc{};
|
||||
|
||||
for(unsigned did0 = 0; did0 < desc_ref.GetLength(I0); ++did0)
|
||||
for(unsigned did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < desc_ref.GetLength(I1); ++did1)
|
||||
for(unsigned did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < desc_ref.GetLength(I2); ++did2)
|
||||
for(unsigned did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < desc_ref.GetLength(I3); ++did3)
|
||||
for(unsigned did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
const unsigned aindex = desc_a.Get1dIndex(did0, did1, did2, did3);
|
||||
const unsigned aindex = src_desc.Get1dIndex(did0, did1, did2, did3);
|
||||
|
||||
const unsigned bindex = desc_b.Get1dIndex(did0, did1, did2, did3);
|
||||
const unsigned did[4] = {did0, did1, did2, did3};
|
||||
|
||||
f(p_a[aindex], p_b[bindex]);
|
||||
const unsigned bindex =
|
||||
dst_desc.Get1dIndex(did[IT0], did[IT1], did[IT2], did[IT3]);
|
||||
|
||||
f(p_src[aindex], p_dst[bindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -81,18 +88,37 @@ __device__ void threadwise_4d_tensor_set_zero(Desc, TFloat* __restrict__ p)
|
||||
{
|
||||
auto f_set_zero = [](TFloat& v) { v = TFloat(0); };
|
||||
|
||||
threadwise_4d_tensor_pointwise_op_unary<TFloat, Desc, decltype(f_set_zero)>(
|
||||
threadwise_4d_tensor_pointwise_operation_unary<TFloat, Desc, decltype(f_set_zero)>(
|
||||
Desc{}, p, f_set_zero);
|
||||
}
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class RefDesc, class Reorder>
|
||||
__device__ void threadwise_4d_tensor_copy_reorder(SrcDesc,
|
||||
TFloat* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
TFloat* __restrict__ p_dst,
|
||||
RefDesc,
|
||||
Reorder)
|
||||
{
|
||||
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
|
||||
|
||||
threadwise_4d_tensor_pointwise_operation_binary_reorder<TFloat,
|
||||
SrcDesc,
|
||||
DstDesc,
|
||||
RefDesc,
|
||||
Reorder,
|
||||
decltype(f_copy)>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, Reorder{}, f_copy);
|
||||
}
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class RefDesc>
|
||||
__device__ void threadwise_4d_tensor_copy(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, RefDesc)
|
||||
{
|
||||
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
|
||||
auto reorder = Sequence<0, 1, 2, 3>{};
|
||||
|
||||
threadwise_4d_tensor_pointwise_op_binary<TFloat, SrcDesc, DstDesc, RefDesc, decltype(f_copy)>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, f_copy);
|
||||
threadwise_4d_tensor_copy_reorder<TFloat, SrcDesc, DstDesc, RefDesc, decltype(reorder)>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, reorder);
|
||||
}
|
||||
|
||||
template <class TFloat, class Desc, class IDim, class NShift>
|
||||
|
||||
Reference in New Issue
Block a user