Merge commit '6c2ca1211ae29802281049843d284ba1bd6511f8' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-27 18:15:18 +00:00
parent 9cdbee7709
commit d3e72e87c4
32 changed files with 2051 additions and 44 deletions

View File

@@ -7,6 +7,7 @@
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace ck_tile {
@@ -14,14 +15,18 @@ namespace ck_tile {
template <ck_tile::index_t NDimSpatial,
typename InDataType,
typename WeiDataType,
typename OutDataType>
typename OutDataType,
typename Elfunc = ck_tile::element_wise::PassThrough,
typename Tuple = ck_tile::tuple<>>
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input,
const HostTensor<WeiDataType>& weight,
HostTensor<OutDataType>& output,
std::vector<ck_tile::long_index_t> conv_strides,
std::vector<ck_tile::long_index_t> conv_dilations,
std::vector<ck_tile::long_index_t> in_left_pads,
std::vector<ck_tile::long_index_t>)
std::vector<ck_tile::long_index_t>,
Elfunc elfunc = Elfunc{},
Tuple ds = {})
{
if(!(input.get_num_of_dimension() == NDimSpatial + 3 &&
weight.get_num_of_dimension() == NDimSpatial + 3 &&
@@ -52,8 +57,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
@@ -95,8 +104,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, ho, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, ho, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, ho, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,
@@ -145,8 +158,12 @@ CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor<InDataType>& input
}
}
}
OutDataType v_acc_converted = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, d_o, ho, wo) = v_acc_converted;
if constexpr(Tuple::size() > 0)
elfunc(v_acc, v_acc, ds.at(ck_tile::number<0>{})(g, n, k, d_o, ho, wo));
else
elfunc(v_acc, v_acc);
OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
output(g, n, k, d_o, ho, wo) = v_acc_out;
};
make_ParallelTensorFunctor(func,