Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-29 08:15:15 +00:00
parent e571490afc
commit 6f6c855c0e
13 changed files with 860 additions and 99 deletions

View File

@@ -38,7 +38,10 @@ auto create_args(int argc, char* argv[])
return std::make_tuple(result, arg_parser);
}
template <typename InDataType, typename OutDataType, typename ComputeDataType>
template <typename InDataType,
typename OutDataType,
typename ComputeDataType,
typename IndexDataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
@@ -84,6 +87,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
constexpr bool OutputIndex = true;
constexpr bool PropagateNan = false;
// Shapes / strides / parameters (NDHWC)
const auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
@@ -100,11 +106,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_index({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::HostTensor<IndexDataType> out_ref_index({N, Do, Ho, Wo, C},
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
ck_tile::DeviceMem out_index_buf(OutputIndex ? out_index.get_element_space_size_in_bytes() : 0);
in_buf.ToDevice(in.data());
@@ -118,10 +129,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
using Problem = ck_tile::PoolProblem<InDataType,
OutDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOp,
false,
false,
OutputIndex,
PropagateNan,
Shape>;
using Kernel = ck_tile::PoolKernel<Problem>;
@@ -131,6 +142,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
OutputIndex ? static_cast<IndexDataType*>(out_index_buf.GetDeviceBuffer()) : nullptr,
input_shape,
output_shape,
input_strides,
@@ -167,12 +179,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(do_validation)
{
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
in, out_ref, kernel_args, ReduceOp{});
out_buf.FromDevice(out.mData.data());
pass = ck_tile::check_err(out, out_ref);
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
ck_tile::reference_pool3d<InDataType,
ComputeDataType,
OutDataType,
IndexDataType,
ReduceOp,
decltype(input_shape),
decltype(window_spatial_lengths),
OutputIndex>(in, out_ref, out_ref_index, kernel_args, ReduceOp{});
if constexpr(OutputIndex)
{
out_index_buf.FromDevice(out_index.mData.data());
pass = ck_tile::check_err(out, out_ref) && ck_tile::check_err(out_index, out_ref_index);
}
else
{
pass = ck_tile::check_err(out, out_ref);
}
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
}
return pass;
@@ -184,5 +212,5 @@ int main(int argc, char* argv[])
if(!result)
return -1;
return run<ck_tile::half_t, ck_tile::half_t, float>(arg_parser) ? 0 : -2;
return run<ck_tile::half_t, ck_tile::half_t, float, ck_tile::index_t>(arg_parser) ? 0 : -2;
}