mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
refactor
This commit is contained in:
@@ -39,8 +39,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn(InDesc,
|
||||
|
||||
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
|
||||
|
||||
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
|
||||
wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
|
||||
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) {
|
||||
wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)(
|
||||
|
||||
@@ -41,8 +41,8 @@ void device_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(InDesc,
|
||||
|
||||
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
|
||||
|
||||
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto s, auto r) {
|
||||
wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r);
|
||||
auto f_reorder_kcsr2csrk = [&](auto k, auto c, auto y, auto x) {
|
||||
wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_reorder_kcsr2csrk, K, C, Y, X)(
|
||||
|
||||
@@ -55,7 +55,7 @@ void device_implicit_gemm_convolution_2_chwn_csrk_khwn(InDesc,
|
||||
Tensor<T> wei_csrk(make_TensorDescriptor(wei_csrk_desc));
|
||||
|
||||
make_ParallelTensorFunctor(
|
||||
[&](auto k, auto c, auto s, auto r) { wei_csrk(c, s, r, k) = wei_kcsr(k, c, s, r); },
|
||||
[&](auto k, auto c, auto y, auto x) { wei_csrk(c, y, x, k) = wei_kcsr(k, c, y, x); },
|
||||
K,
|
||||
C,
|
||||
Y,
|
||||
|
||||
@@ -204,12 +204,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(const Float* const __restric
|
||||
__syncthreads();
|
||||
|
||||
// a series of batched GEMM
|
||||
for(unsigned s = 0; s < Y; ++s)
|
||||
for(unsigned y = 0; y < Y; ++y)
|
||||
{
|
||||
for(unsigned r = 0; r < X; ++r)
|
||||
for(unsigned x = 0; x < X; ++x)
|
||||
{
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_out_thread,
|
||||
[](auto& acc, const auto&& v) { acc += v; });
|
||||
}
|
||||
|
||||
@@ -245,14 +245,14 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(
|
||||
__syncthreads();
|
||||
|
||||
// a series of batched GEMM
|
||||
for(unsigned s = 0; s < Y; ++s)
|
||||
for(unsigned y = 0; y < Y; ++y)
|
||||
{
|
||||
for(unsigned r = 0; r < X; ++r)
|
||||
for(unsigned x = 0; x < X; ++x)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
|
||||
@@ -275,9 +275,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
|
||||
|
||||
// compute on current data
|
||||
// a series of GEMM
|
||||
for(unsigned s = 0; s < Y; ++s)
|
||||
for(unsigned y = 0; y < Y; ++y)
|
||||
{
|
||||
for(unsigned r = 0; r < X; ++r)
|
||||
for(unsigned x = 0; x < X; ++x)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
#if 1
|
||||
@@ -285,8 +285,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
|
||||
#else
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#endif
|
||||
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block_now + s * Wi + r,
|
||||
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
@@ -305,9 +305,9 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for(unsigned s = 0; s < Y; ++s)
|
||||
for(unsigned y = 0; y < Y; ++y)
|
||||
{
|
||||
for(unsigned r = 0; r < X; ++r)
|
||||
for(unsigned x = 0; x < X; ++x)
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
#if 0
|
||||
@@ -315,8 +315,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_chwn_csrk_khwn_lds_double_b
|
||||
#else
|
||||
blockwise_gemm.Run_RegisterDoubleBuffer
|
||||
#endif
|
||||
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block_now + s * Wi + r,
|
||||
(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, y, x, 0),
|
||||
p_in_block_now + y * Wi + x,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
}
|
||||
|
||||
@@ -8,16 +8,16 @@
|
||||
#include <iostream>
|
||||
|
||||
template <class Range>
|
||||
std::ostream& LogRange(std::ostream& os, Range&& r, std::string delim)
|
||||
std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
|
||||
{
|
||||
bool first = true;
|
||||
for(auto&& x : r)
|
||||
for(auto&& v : range)
|
||||
{
|
||||
if(first)
|
||||
first = false;
|
||||
else
|
||||
os << delim;
|
||||
os << x;
|
||||
os << v;
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
@@ -38,16 +38,16 @@ __device__ void threadwise_direct_convolution_1(InDesc,
|
||||
{
|
||||
for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c)
|
||||
{
|
||||
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
|
||||
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r)
|
||||
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
const unsigned hi = ho + s;
|
||||
const unsigned wi = wo + r;
|
||||
const unsigned hi = ho + y;
|
||||
const unsigned wi = wo + x;
|
||||
|
||||
const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi);
|
||||
|
||||
const unsigned wei_index = wei_desc.Get1dIndex(k, c, s, r);
|
||||
const unsigned wei_index = wei_desc.Get1dIndex(k, c, y, x);
|
||||
|
||||
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
|
||||
|
||||
@@ -153,18 +153,18 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
#if 0
|
||||
// this verison reused old input data in register, and read new data from LDS
|
||||
// loop over vertical direction
|
||||
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
|
||||
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
// read first input
|
||||
threadwise_4d_tensor_copy(in_desc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, 0),
|
||||
p_in + in_desc.Get1dIndex(0, 0, y, 0),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_reg_desc.GetLengths());
|
||||
|
||||
// read first 1x1 weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, 0),
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, y, 0),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc.GetLengths());
|
||||
@@ -174,11 +174,11 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
|
||||
// loop over horizontal direction
|
||||
for(unsigned r = 1; r < wei_desc.GetLength(I3); ++r)
|
||||
for(unsigned x = 1; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
// read new weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, y, x),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc.GetLengths());
|
||||
@@ -189,7 +189,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
// read new input
|
||||
threadwise_4d_tensor_copy(
|
||||
in_desc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, r + in_reg_desc.GetLength(I3) - 1),
|
||||
p_in + in_desc.Get1dIndex(0, 0, y, x + in_reg_desc.GetLength(I3) - 1),
|
||||
in_reg_desc,
|
||||
p_in_reg +
|
||||
in_reg_desc.Get1dIndex(0, 0, 0, in_reg_desc.GetLength(I3) - in_w_new_read),
|
||||
@@ -203,21 +203,21 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
#elif 1
|
||||
// this version read all input from LDS when filter moves
|
||||
// loop over vertical direction
|
||||
for(unsigned s = 0; s < wei_desc.GetLength(I2); ++s)
|
||||
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
// loop over horizontal direction
|
||||
for(unsigned r = 0; r < wei_desc.GetLength(I3); ++r)
|
||||
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
// read new weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, s, r),
|
||||
p_wei + wei_desc.Get1dIndex(0, 0, y, x),
|
||||
wei_reg_desc,
|
||||
p_wei_reg,
|
||||
wei_reg_desc.GetLengths());
|
||||
|
||||
// read new input
|
||||
threadwise_4d_tensor_copy(in_desc,
|
||||
p_in + in_desc.Get1dIndex(0, 0, s, r),
|
||||
p_in + in_desc.Get1dIndex(0, 0, y, x),
|
||||
in_reg_desc,
|
||||
p_in_reg,
|
||||
in_reg_desc.GetLengths());
|
||||
|
||||
Reference in New Issue
Block a user