Files
composable_kernel/tile_engine/ops/pooling/pool_benchmark.hpp
2025-11-27 11:31:53 +00:00

124 lines
3.3 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include <iomanip>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
/// Performance metric selector used when naming and comparing benchmark results.
/// Enumerators rely on default numbering, which yields the values noted below.
enum class Metric
{
    LATENCY,  // == 0
    TFLOPS,   // == 1
    BANDWIDTH // == 2
};
/// Maps a Metric enumerator to its lowercase textual name.
///
/// @param m  Metric to translate.
/// @return   Pointer to a string literal: "latency", "tflops" or "bandwidth".
/// @throws std::invalid_argument if @p m is not one of the known enumerators.
inline constexpr auto get_metric_name(Metric m)
{
    if(m == Metric::LATENCY)
    {
        return "latency";
    }
    if(m == Metric::TFLOPS)
    {
        return "tflops";
    }
    if(m == Metric::BANDWIDTH)
    {
        return "bandwidth";
    }
    // Any value outside the enumerator set (e.g. produced by a cast) is rejected.
    throw std::invalid_argument("Unsupported metric type");
}
/// Plain aggregate describing one pooling benchmark configuration.
///
/// The fields are free-form strings and flags; their exact meaning is taken from
/// their names (data type names, block shape, reduction op) -- confirm against the
/// code that populates them.
struct PoolProblem
{
    std::string inDType;
    std::string outDType;
    std::string computeDType;
    std::string indexDType;
    std::string blockShape;
    std::string reduceOp;
    bool outputIndex;
    bool propagateNan;

    /// Writes the problem as a JSON object.
    ///
    /// String members are emitted as quoted JSON strings and the boolean members
    /// as the bare literals `true` / `false`. No escaping is applied, so the
    /// string members must not contain quotes, backslashes or control characters.
    ///
    /// @param os       Destination stream.
    /// @param problem  Problem to serialize.
    /// @return         @p os, to allow chaining.
    friend std::ostream& operator<<(std::ostream& os, const PoolProblem& problem)
    {
        os << "{\n"
           << " \"inDType\":\"" << problem.inDType << "\",\n"
           << " \"outDType\":\"" << problem.outDType << "\",\n"
           << " \"computeDType\":\"" << problem.computeDType << "\",\n"
           << " \"indexDType\":\"" << problem.indexDType << "\",\n"
           << " \"blockShape\":\"" << problem.blockShape << "\",\n"
           << " \"reduceOp\":\"" << problem.reduceOp << "\",\n"
           << " \"outputIndex\":" << (problem.outputIndex ? "true" : "false") << ",\n"
           << " \"propagateNan\":" << (problem.propagateNan ? "true" : "false")
           << "\n"
           << "}";
        return os;
    }
};
/// Measured performance figures for one kernel run.
///
/// Units follow the labels used when the result is printed: latency in ms,
/// throughput in TFlops and bandwidth in GB/s.
struct PerformanceResult
{
    double latency_;
    double tflops_;
    double bandwidth_;

    /// Tells whether @p a is strictly better than @p b under metric @p m.
    ///
    /// Latency is better when lower; tflops and bandwidth are better when higher.
    ///
    /// @throws std::invalid_argument if @p m is not a known Metric enumerator.
    static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m)
    {
        switch(m)
        {
        case Metric::LATENCY: return a.latency_ < b.latency_;
        case Metric::TFLOPS: return a.tflops_ > b.tflops_;
        case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_;
        default: throw std::invalid_argument("Unsupported metric type");
        }
    }

    /// Writes the result as a JSON object with fixed-point, two-decimal values.
    ///
    /// The stream's formatting flags and precision are restored before returning,
    /// so printing a result does not alter how the caller's later output is
    /// formatted.
    ///
    /// @param os      Destination stream.
    /// @param result  Result to serialize.
    /// @return        @p os, to allow chaining.
    friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result)
    {
        // std::fixed and std::setprecision are sticky; remember the caller's state.
        const std::ios_base::fmtflags saved_flags = os.flags();
        const std::streamsize saved_precision     = os.precision();

        os << "{\n"
           << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_
           << ",\n"
           << " \"tflops(TFlops)\": " << result.tflops_ << ",\n"
           << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n"
           << "}";

        os.flags(saved_flags);
        os.precision(saved_precision);
        return os;
    }
};
struct KernelInstance
{
std::string name_;
PoolProblem problem_;
PerformanceResult perf_result_;
static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m)
{
return PerformanceResult::compare(a.perf_result_, b.perf_result_, m);
}
friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj)
{
os << "{\n"
<< " \"name\": \"" << obj.name_ << "\",\n"
<< " \"problem\": " << obj.problem_ << ",\n"
<< " \"perf_result\": " << obj.perf_result_ << "\n"
<< "}";
return os;
}
};
/// Reads the first line of `/opt/rocm/.info/version`.
///
/// @return The first line of that file, or "Unknown" when the file cannot be
///         opened, cannot be read, or its first line is empty.
inline std::string get_rocm_version()
{
    std::ifstream version_file("/opt/rocm/.info/version");
    std::string version;
    // Fall back unless a non-empty first line was read successfully, so callers
    // never receive an empty version string.
    if(version_file.is_open() && std::getline(version_file, version) && !version.empty())
    {
        return version;
    }
    return "Unknown";
}