mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
another version of direct conv
[ROCm/composable_kernel commit: 39775d484c]
This commit is contained in:
@@ -35,41 +35,41 @@ __device__ void threadwise_4d_tensor_pointwise_op_unary(Desc, TFloat* __restrict
|
||||
}
|
||||
}
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class F>
|
||||
template <class TFloat, class DescA, class DescB, class DescRef, class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_op_binary(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
DescA, TFloat* const __restrict__ p_a, DescB, TFloat* __restrict__ p_b, DescRef, F f)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
|
||||
constexpr auto desc_a = DescA{};
|
||||
constexpr auto desc_b = DescB{};
|
||||
constexpr auto desc_ref = DescRef{};
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "threadwise_4d_tensor_op_binary: src_desc: ");
|
||||
print_ConstantTensorDescriptor(dst_desc, "threadwise_4d_tensor_op_binary: dst_desc: ");
|
||||
print_ConstantTensorDescriptor(desc_a, "threadwise_4d_tensor_op_binary: desc_a: ");
|
||||
print_ConstantTensorDescriptor(desc_b, "threadwise_4d_tensor_op_binary: desc_b: ");
|
||||
print_ConstantTensorDescriptor(desc_ref, "threadwise_4d_tensor_op_binary: desc_ref: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
|
||||
for(unsigned did0 = 0; did0 < desc_ref.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < src_desc.GetLength(I1); ++did1)
|
||||
for(unsigned did1 = 0; did1 < desc_ref.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < src_desc.GetLength(I2); ++did2)
|
||||
for(unsigned did2 = 0; did2 < desc_ref.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
|
||||
for(unsigned did3 = 0; did3 < desc_ref.GetLength(I3); ++did3)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1, did2, did3);
|
||||
const unsigned aindex = desc_a.Get1dIndex(did0, did1, did2, did3);
|
||||
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
|
||||
const unsigned bindex = desc_b.Get1dIndex(did0, did1, did2, did3);
|
||||
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
f(p_a[aindex], p_b[bindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -85,20 +85,18 @@ __device__ void threadwise_4d_tensor_set_zero(Desc, TFloat* __restrict__ p)
|
||||
Desc{}, p, f_set_zero);
|
||||
}
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc>
|
||||
__device__ void threadwise_4d_tensor_copy(SrcDesc,
|
||||
TFloat* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
TFloat* __restrict__ p_dst)
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class RefDesc>
|
||||
__device__ void threadwise_4d_tensor_copy(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, RefDesc)
|
||||
{
|
||||
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
|
||||
|
||||
threadwise_4d_tensor_pointwise_op_binary<TFloat, SrcDesc, DstDesc, decltype(f_copy)>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, f_copy);
|
||||
threadwise_4d_tensor_pointwise_op_binary<TFloat, SrcDesc, DstDesc, RefDesc, decltype(f_copy)>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, RefDesc{}, f_copy);
|
||||
}
|
||||
|
||||
template <class TFloat, class Desc, class IDim>
|
||||
__device__ void threadwise_4d_tensor_shift_down(Desc, TFloat* __restrict__ p, IDim, unsigned shift)
|
||||
template <class TFloat, class Desc, class IDim, class NShift>
|
||||
__device__ void threadwise_4d_tensor_shift_down(Desc, TFloat* __restrict__ p, IDim, NShift)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
@@ -114,17 +112,19 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, TFloat* __restrict__ p, ID
|
||||
}
|
||||
#endif
|
||||
|
||||
const unsigned did0_end =
|
||||
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - shift : desc.GetLength(I0);
|
||||
constexpr unsigned nshift = NShift::mValue;
|
||||
|
||||
const unsigned did1_end =
|
||||
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - shift : desc.GetLength(I1);
|
||||
constexpr unsigned did0_end =
|
||||
is_same<decltype(I0), IDim>::value ? desc.GetLength(I0) - nshift : desc.GetLength(I0);
|
||||
|
||||
const unsigned did2_end =
|
||||
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - shift : desc.GetLength(I2);
|
||||
constexpr unsigned did1_end =
|
||||
is_same<decltype(I1), IDim>::value ? desc.GetLength(I1) - nshift : desc.GetLength(I1);
|
||||
|
||||
const unsigned did3_end =
|
||||
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - shift : desc.GetLength(I3);
|
||||
constexpr unsigned did2_end =
|
||||
is_same<decltype(I2), IDim>::value ? desc.GetLength(I2) - nshift : desc.GetLength(I2);
|
||||
|
||||
constexpr unsigned did3_end =
|
||||
is_same<decltype(I3), IDim>::value ? desc.GetLength(I3) - nshift : desc.GetLength(I3);
|
||||
|
||||
for(unsigned did0 = 0; did0 < did0_end; ++did0)
|
||||
{
|
||||
@@ -136,11 +136,11 @@ __device__ void threadwise_4d_tensor_shift_down(Desc, TFloat* __restrict__ p, ID
|
||||
{
|
||||
const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3);
|
||||
|
||||
const unsigned sindex = dindex + shift * desc.GetStride(IDim{});
|
||||
const unsigned sindex = dindex + nshift * desc.GetStride(IDim{});
|
||||
|
||||
p[dindex] = p[sindex];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user