mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 12:00:07 +00:00
Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop
This commit is contained in:
@@ -38,7 +38,10 @@ auto create_args(int argc, char* argv[])
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InDataType, typename OutDataType, typename ComputeDataType>
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename ComputeDataType,
|
||||
typename IndexDataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
|
||||
@@ -84,6 +87,9 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
constexpr bool OutputIndex = true;
|
||||
constexpr bool PropagateNan = false;
|
||||
|
||||
// Shapes / strides / parameters (NDHWC)
|
||||
const auto input_shape = ck_tile::make_tuple(N, D, H, W, C);
|
||||
const auto output_shape = ck_tile::make_tuple(N, Do, Ho, Wo, C);
|
||||
@@ -100,11 +106,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<OutDataType> out_ref({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> out_index({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
ck_tile::HostTensor<IndexDataType> out_ref_index({N, Do, Ho, Wo, C},
|
||||
{Do * Ho * Wo * C, Ho * Wo * C, Wo * C, C, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<InDataType>{-5.f, 5.f}(in);
|
||||
|
||||
ck_tile::DeviceMem in_buf(in.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_buf(out.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem out_index_buf(OutputIndex ? out_index.get_element_space_size_in_bytes() : 0);
|
||||
|
||||
in_buf.ToDevice(in.data());
|
||||
|
||||
@@ -118,10 +129,10 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
using Problem = ck_tile::PoolProblem<InDataType,
|
||||
OutDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOp,
|
||||
false,
|
||||
false,
|
||||
OutputIndex,
|
||||
PropagateNan,
|
||||
Shape>;
|
||||
using Kernel = ck_tile::PoolKernel<Problem>;
|
||||
|
||||
@@ -131,6 +142,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
auto host_args = ck_tile::PoolHostArgs<decltype(input_shape), decltype(window_spatial_lengths)>{
|
||||
static_cast<InDataType*>(in_buf.GetDeviceBuffer()),
|
||||
static_cast<OutDataType*>(out_buf.GetDeviceBuffer()),
|
||||
OutputIndex ? static_cast<IndexDataType*>(out_index_buf.GetDeviceBuffer()) : nullptr,
|
||||
input_shape,
|
||||
output_shape,
|
||||
input_strides,
|
||||
@@ -167,12 +179,28 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
ck_tile::reference_pool3d<InDataType, ComputeDataType, OutDataType>(
|
||||
in, out_ref, kernel_args, ReduceOp{});
|
||||
out_buf.FromDevice(out.mData.data());
|
||||
pass = ck_tile::check_err(out, out_ref);
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
ck_tile::reference_pool3d<InDataType,
|
||||
ComputeDataType,
|
||||
OutDataType,
|
||||
IndexDataType,
|
||||
ReduceOp,
|
||||
decltype(input_shape),
|
||||
decltype(window_spatial_lengths),
|
||||
OutputIndex>(in, out_ref, out_ref_index, kernel_args, ReduceOp{});
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
out_index_buf.FromDevice(out_index.mData.data());
|
||||
pass = ck_tile::check_err(out, out_ref) && ck_tile::check_err(out_index, out_ref_index);
|
||||
}
|
||||
else
|
||||
{
|
||||
pass = ck_tile::check_err(out, out_ref);
|
||||
}
|
||||
|
||||
std::cout << "valid:" << (pass ? "y" : "n") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
@@ -184,5 +212,5 @@ int main(int argc, char* argv[])
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float>(arg_parser) ? 0 : -2;
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, ck_tile::index_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user