mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 19:09:59 +00:00
@@ -28,28 +28,28 @@ __device__ void threadwise_direct_convolution_1(InDesc,
|
||||
}
|
||||
#endif
|
||||
|
||||
for(unsigned n = 0; n < out_desc.GetLength(I0); ++n)
|
||||
for(index_t n = 0; n < out_desc.GetLength(I0); ++n)
|
||||
{
|
||||
for(unsigned k = 0; k < out_desc.GetLength(I1); ++k)
|
||||
for(index_t k = 0; k < out_desc.GetLength(I1); ++k)
|
||||
{
|
||||
for(unsigned ho = 0; ho < out_desc.GetLength(I2); ++ho)
|
||||
for(index_t ho = 0; ho < out_desc.GetLength(I2); ++ho)
|
||||
{
|
||||
for(unsigned wo = 0; wo < out_desc.GetLength(I3); ++wo)
|
||||
for(index_t wo = 0; wo < out_desc.GetLength(I3); ++wo)
|
||||
{
|
||||
for(unsigned c = 0; c < wei_desc.GetLength(I1); ++c)
|
||||
for(index_t c = 0; c < wei_desc.GetLength(I1); ++c)
|
||||
{
|
||||
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
|
||||
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
const unsigned hi = ho + y;
|
||||
const unsigned wi = wo + x;
|
||||
const index_t hi = ho + y;
|
||||
const index_t wi = wo + x;
|
||||
|
||||
const unsigned in_index = in_desc.Get1dIndex(n, c, hi, wi);
|
||||
const index_t in_index = in_desc.Get1dIndex(n, c, hi, wi);
|
||||
|
||||
const unsigned wei_index = wei_desc.Get1dIndex(k, c, y, x);
|
||||
const index_t wei_index = wei_desc.Get1dIndex(k, c, y, x);
|
||||
|
||||
const unsigned out_index = out_desc.Get1dIndex(n, k, ho, wo);
|
||||
const index_t out_index = out_desc.Get1dIndex(n, k, ho, wo);
|
||||
|
||||
fused_multiply_accumulate(
|
||||
p_out[out_index], p_wei[wei_index], p_in[in_index]);
|
||||
@@ -125,7 +125,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
Data p_in_reg[in_reg_desc.GetElementSpace()];
|
||||
Data p_wei_reg[wei_reg_desc.GetElementSpace()];
|
||||
|
||||
constexpr unsigned in_w_new_read = 1;
|
||||
constexpr index_t in_w_new_read = 1;
|
||||
|
||||
constexpr auto in_desc_reg_new_read =
|
||||
make_ConstantTensorDescriptor(Sequence<in_reg_desc.GetLength(I0),
|
||||
@@ -136,7 +136,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
#if 0
|
||||
// this verison reused old input data in register, and read new data from LDS
|
||||
// loop over vertical direction
|
||||
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
// read first input
|
||||
threadwise_4d_tensor_copy(in_desc,
|
||||
@@ -157,7 +157,7 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
in_reg_desc, p_in_reg, wei_reg_desc, p_wei_reg, out_desc, p_out);
|
||||
|
||||
// loop over horizontal direction
|
||||
for(unsigned x = 1; x < wei_desc.GetLength(I3); ++x)
|
||||
for(index_t x = 1; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
// read new weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
@@ -186,10 +186,10 @@ __device__ void threadwise_direct_convolution_3(InDesc,
|
||||
#elif 1
|
||||
// this version read all input from LDS when filter moves
|
||||
// loop over vertical direction
|
||||
for(unsigned y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
for(index_t y = 0; y < wei_desc.GetLength(I2); ++y)
|
||||
{
|
||||
// loop over horizontal direction
|
||||
for(unsigned x = 0; x < wei_desc.GetLength(I3); ++x)
|
||||
for(index_t x = 0; x < wei_desc.GetLength(I3); ++x)
|
||||
{
|
||||
// read new weight
|
||||
threadwise_4d_tensor_copy(wei_desc,
|
||||
|
||||
Reference in New Issue
Block a user