mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
clean up
This commit is contained in:
@@ -92,47 +92,44 @@ __device__ void blockwise_convolution(InDesc,
|
||||
auto f_copy = [](const TFloat& src, TFloat& dst) { dst = src; };
|
||||
|
||||
// copy input tensor into register
|
||||
threadwise_4d_tensor_op<TFloat,
|
||||
decltype(in_thread_src_desc),
|
||||
decltype(in_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
threadwise_4d_tensor_op_in<TFloat,
|
||||
decltype(in_thread_src_desc),
|
||||
decltype(in_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
in_thread_src_desc,
|
||||
p_in + in_desc.Get1dIndex(
|
||||
n_thread_work_begin, 0, hi_thread_work_begin, wi_thread_work_begin),
|
||||
in_thread_dst_desc,
|
||||
p_in_thread,
|
||||
f_copy,
|
||||
false);
|
||||
f_copy);
|
||||
|
||||
for(unsigned k_thread_work_begin = 0; k_thread_work_begin < KPerBlock;
|
||||
++k_thread_work_begin)
|
||||
{
|
||||
// copy weight tensor into register
|
||||
threadwise_4d_tensor_op<TFloat,
|
||||
decltype(wei_thread_src_desc),
|
||||
decltype(wei_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
threadwise_4d_tensor_op_wei<TFloat,
|
||||
decltype(wei_thread_src_desc),
|
||||
decltype(wei_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
wei_thread_src_desc,
|
||||
p_wei + wei_desc.Get1dIndex(k_thread_work_begin, 0, 0, 0),
|
||||
wei_thread_dst_desc,
|
||||
p_wei_thread,
|
||||
f_copy,
|
||||
false);
|
||||
f_copy);
|
||||
|
||||
// copy output tensor into register
|
||||
threadwise_4d_tensor_op<TFloat,
|
||||
decltype(out_thread_src_desc),
|
||||
decltype(out_thread_dst_desc),
|
||||
decltype(f_copy)>(out_thread_src_desc,
|
||||
p_out +
|
||||
out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
f_copy,
|
||||
false);
|
||||
threadwise_4d_tensor_op_out<TFloat,
|
||||
decltype(out_thread_src_desc),
|
||||
decltype(out_thread_dst_desc),
|
||||
decltype(f_copy)>(
|
||||
out_thread_src_desc,
|
||||
p_out + out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
f_copy);
|
||||
|
||||
// threadwise convolution
|
||||
threadwise_direct_convolution<TFloat,
|
||||
@@ -146,19 +143,18 @@ __device__ void blockwise_convolution(InDesc,
|
||||
p_out_thread);
|
||||
|
||||
// accumulate output tensor into LDS
|
||||
threadwise_4d_tensor_op<TFloat,
|
||||
decltype(out_thread_dst_desc),
|
||||
decltype(out_thread_src_desc),
|
||||
decltype(f_copy)>(out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
out_thread_src_desc,
|
||||
p_out +
|
||||
out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
f_copy,
|
||||
false);
|
||||
threadwise_4d_tensor_op_out<TFloat,
|
||||
decltype(out_thread_dst_desc),
|
||||
decltype(out_thread_src_desc),
|
||||
decltype(f_copy)>(
|
||||
out_thread_dst_desc,
|
||||
p_out_thread,
|
||||
out_thread_src_desc,
|
||||
p_out + out_desc.Get1dIndex(n_thread_work_begin,
|
||||
k_thread_work_begin,
|
||||
ho_thread_work_begin,
|
||||
wo_thread_work_begin),
|
||||
f_copy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,12 +5,8 @@
|
||||
|
||||
#if THREADWISE_TENSOR_OP_METHOD == 0
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class F>
|
||||
__device__ void threadwise_4d_tensor_op(SrcDesc,
|
||||
TFloat* const __restrict__ p_src,
|
||||
DstDesc,
|
||||
TFloat* __restrict__ p_dst,
|
||||
F f,
|
||||
bool flag = false)
|
||||
__device__ void threadwise_4d_tensor_op_in(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
@@ -30,9 +26,122 @@ __device__ void threadwise_4d_tensor_op(SrcDesc,
|
||||
}
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
if(flag && threadIdx.x != 0)
|
||||
return;
|
||||
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < src_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < src_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
const unsigned sindex =
|
||||
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
|
||||
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
|
||||
|
||||
const unsigned dindex =
|
||||
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class F>
|
||||
__device__ void threadwise_4d_tensor_op_wei(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc);
|
||||
print_ConstantTensorDescriptor(dst_desc);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
|
||||
{
|
||||
for(unsigned did1 = 0; did1 < src_desc.GetLength(I1); ++did1)
|
||||
{
|
||||
for(unsigned did2 = 0; did2 < src_desc.GetLength(I2); ++did2)
|
||||
{
|
||||
for(unsigned did3 = 0; did3 < src_desc.GetLength(I3); ++did3)
|
||||
{
|
||||
const unsigned sindex =
|
||||
src_desc.GetStride(I0) * did0 + src_desc.GetStride(I1) * did1 +
|
||||
src_desc.GetStride(I2) * did2 + src_desc.GetStride(I3) * did3;
|
||||
|
||||
const unsigned dindex =
|
||||
dst_desc.GetStride(I0) * did0 + dst_desc.GetStride(I1) * did1 +
|
||||
dst_desc.GetStride(I2) * did2 + dst_desc.GetStride(I3) * did3;
|
||||
|
||||
f(p_src[sindex], p_dst[dindex]);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
printf("threadwise_4d_tensor_op: 1: thread id %u, \t"
|
||||
"sindex %u, p_src[sindex] %f, \t"
|
||||
"dindex %u, p_dst[dindex] %f\n",
|
||||
threadIdx.x,
|
||||
sindex,
|
||||
p_src[sindex],
|
||||
dindex,
|
||||
p_dst[dindex]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class TFloat, class SrcDesc, class DstDesc, class F>
|
||||
__device__ void threadwise_4d_tensor_op_out(
|
||||
SrcDesc, TFloat* const __restrict__ p_src, DstDesc, TFloat* __restrict__ p_dst, F f)
|
||||
{
|
||||
constexpr auto I0 = Index<0>{};
|
||||
constexpr auto I1 = Index<1>{};
|
||||
constexpr auto I2 = Index<2>{};
|
||||
constexpr auto I3 = Index<3>{};
|
||||
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
static_assert(is_same<decltype(src_desc.GetLengths()), decltype(dst_desc.GetLengths())>::value);
|
||||
|
||||
#if 0
|
||||
if(threadIdx.x == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(src_desc);
|
||||
print_ConstantTensorDescriptor(dst_desc);
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned did0 = 0; did0 < src_desc.GetLength(I0); ++did0)
|
||||
|
||||
Reference in New Issue
Block a user