mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
@@ -1,50 +1,34 @@
|
||||
#pragma once
|
||||
#include "constant_tensor_descriptor.cuh"
|
||||
|
||||
#define THREADWISE_TENSOR_OP_METHOD 0
|
||||
|
||||
#if THREADWISE_TENSOR_OP_METHOD == 0
|
||||
template <class TFloat, class DstDesc, class F>
|
||||
__device__ void threadwise_4d_tensor_op_unary(DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
template <class TFloat, class Desc, class F>
|
||||
__device__ void threadwise_4d_tensor_pointwise_op_unary(Desc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
constexpr auto desc = Desc{};
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(dst_desc, "threadwise_4d_tensor_op_unary: ");
|
||||
print_ConstantTensorDescriptor(desc, "threadwise_4d_tensor_op_unary: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned did0 = 0; did0 < dst_desc.GetLength(I0); ++did0)
|
||||
for(unsigned did0 = 0; did0 < desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < dst_desc.GetLength(I1); ++did1)
|
||||
for(unsigned did1 = 0; did1 < desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < dst_desc.GetLength(I2); ++did2)
|
||||
for(unsigned did2 = 0; did2 < desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < dst_desc.GetLength(I3); ++did3)
|
||||
for(unsigned did3 = 0; did3 < desc.GetLength(I3); ++did3)
|
||||
{
|
||||
const unsigned dindex =
|
||||
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
const unsigned dindex = desc.Get1dIndex(did0, did1, did2, did3);
|
||||
|
||||
f(p_dst[dindex]);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op_unary: thread id %u, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -52,13 +36,13 @@ __device__ void threadwise_4d_tensor_op_unary(DstDesc, TFloat* __restrict__ p_ds
|
||||
}
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class F>
|
||||
__device__ void threadwise_4d_tensor_op_binary(
|
||||
__device__ void threadwise_4d_tensor_pointwise_op_binary(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
@@ -81,99 +65,34 @@ __device__ void threadwise_4d_tensor_op_binary(
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
const unsigned sindex =
|
||||
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
|
||||
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1, did2, did3);
|
||||
|
||||
const unsigned dindex =
|
||||
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1, did2, did3);
|
||||
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op_binary: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#if THREADWISE_TENSOR_OP_METHOD == 1
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class F>
|
||||
__device__ void threadwise_4d_tensor_op(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
template <class TFloat, class Desc>
|
||||
__device__ void threadwise_4d_tensor_set_zero(Desc, TFloat* __restrict__ p_dst)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
auto f_set_zero = [](TFloat& v) { v = TFloat(0); };
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc, "threadwise_4d_tensor_op: src_desc: ");
|
||||
print_ConstantTensorDescriptor(dst_desc, "threadwise_4d_tensor_op: dst_desc: ");
|
||||
}
|
||||
#endif
|
||||
|
||||
unsigned sindex = 0;
|
||||
unsigned dindex = 0;
|
||||
|
||||
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < src_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < src_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
sindex += src_desc.GetStride(I3);
|
||||
dindex += dst_desc.GetStride(I3);
|
||||
}
|
||||
|
||||
sindex += src_desc.GetStride(I2) - src_desc.GetLength(I3) * src_desc.GetStride(I3);
|
||||
dindex += dst_desc.GetStride(I2) - dst_desc.GetLength(I3) * dst_desc.GetStride(I3);
|
||||
}
|
||||
|
||||
sindex += src_desc.GetStride(I1) - src_desc.GetLength(I2) * src_desc.GetStride(I2);
|
||||
dindex += dst_desc.GetStride(I1) - dst_desc.GetLength(I2) * dst_desc.GetStride(I2);
|
||||
}
|
||||
|
||||
sindex += src_desc.GetStride(I0) - src_desc.GetLength(I1) * src_desc.GetStride(I1);
|
||||
dindex += dst_desc.GetStride(I0) - dst_desc.GetLength(I1) * dst_desc.GetStride(I1);
|
||||
}
|
||||
threadwise_4d_tensor_pointwise_op_unary<TFloat, Desc, decltype(f_set_zero)>(
|
||||
Desc{}, p_dst, f_set_zero);
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc>
|
||||
__device__ void threadwise_4d_tensor_copy(SrcDesc,
|
||||
TFloat* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
TFloat* __restrict__ p_dst)
|
||||
{
|
||||
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
|
||||
|
||||
threadwise_4d_tensor_pointwise_op_binary<TFloat, SrcDesc, DstDesc, decltype(f_copy)>(
|
||||
SrcDesc{}, p_src, DstDesc{}, p_dst, f_copy);
|
||||
}
|
||||
Reference in New Issue
Block a user