Files
composable_kernel/tile_engine/ops/pooling/pooling_common.hpp
Aleksander Dudek 9bfcce5566 fix formating
2026-02-10 18:45:06 +00:00

48 lines
1.5 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <string>
#include <sstream>
#include <iostream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/pooling.hpp"
namespace ck_tile {
/// @brief Kernel trait parameters for pooling tile_engine configurations
/// @brief Kernel trait parameters for pooling tile_engine configurations
///
/// The flag members carry default initializers so that a default-initialized
/// instance (e.g. `PoolingKernelTraits t;`) has well-defined values; without
/// them the bools would be indeterminate and to_string(), which reads every
/// flag, would have undefined behavior. The type remains an aggregate, so
/// brace initialization such as `PoolingKernelTraits{"max", true, false, true}`
/// keeps working unchanged.
struct PoolingKernelTraits
{
    std::string reduce_op;      // "max" or "avg"
    bool output_index  = false; // Whether to output indices (max pooling)
    bool propagate_nan = false; // Whether to propagate NaN values
    bool cross_warp    = false; // Whether cross-warp reduction is used

    /// @brief Encode the traits as an underscore-separated name fragment.
    /// @return "<reduce_op>_<idx|noidx>_<nan|nonan>_<crosswarp|nocrosswarp>",
    ///         e.g. "max_idx_nonan_crosswarp".
    std::string to_string() const
    {
        std::ostringstream oss;
        oss << reduce_op << "_" << (output_index ? "idx" : "noidx") << "_"
            << (propagate_nan ? "nan" : "nonan") << "_"
            << (cross_warp ? "crosswarp" : "nocrosswarp");
        return oss.str();
    }
};
/// @brief Extract traits from a kernel name string
/// @brief Extract traits from a kernel name string
///
/// Decoding is a plain substring search over @p name:
///  - reduce_op becomes "max" if "max" occurs anywhere in the name, and
///    "avg" otherwise.
///  - Each flag is set when its positive marker ("idx", "nan", "crosswarp")
///    occurs and its negated marker ("noidx", "nonan", "nocrosswarp") does
///    not. The negated check is needed because every negated marker contains
///    the positive marker as a substring.
/// Names produced by PoolingKernelTraits::to_string() with reduce_op "max" or
/// "avg" decode back to the same traits.
/// @param name Kernel name to inspect.
/// @return The decoded traits.
inline PoolingKernelTraits extract_pooling_traits_from_name(const std::string& name)
{
    // True when `marker` appears anywhere inside the name.
    const auto has = [&name](const char* marker) {
        return name.find(marker) != std::string::npos;
    };
    // A flag is on only if its marker is present and its "no"-prefixed twin is absent.
    const auto flag = [&has](const char* on_marker, const char* off_marker) {
        return has(on_marker) && !has(off_marker);
    };

    PoolingKernelTraits result{};
    result.reduce_op     = has("max") ? "max" : "avg";
    result.output_index  = flag("idx", "noidx");
    result.propagate_nan = flag("nan", "nonan");
    result.cross_warp    = flag("crosswarp", "nocrosswarp");
    return result;
}
} // namespace ck_tile