// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include #include #include #include #include #include #include #include #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" // Helper function to determine if a layout is row-major template constexpr auto is_row_major(Layout) { return ck_tile::bool_constant>{}; } enum class Metric { LATENCY = 0, TFLOPS = 1, BANDWIDTH = 2 }; inline constexpr auto get_metric_name(Metric m) { switch(m) { case Metric::LATENCY: return "latency"; case Metric::TFLOPS: return "tflops"; case Metric::BANDWIDTH: return "bandwidth"; default: throw std::invalid_argument("Unsupported metric type"); } } struct PerformanceResult { double latency_; double tflops_; double bandwidth_; static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m) { switch(m) { case Metric::LATENCY: return a.latency_ < b.latency_; case Metric::TFLOPS: return a.tflops_ > b.tflops_; case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_; default: throw std::invalid_argument("Unsupported metric type"); } } }; template struct KernelInstance { std::string name_; Problem problem_; PerformanceResult perf_result_; static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m) { return PerformanceResult::compare(a.perf_result_, b.perf_result_, m); } }; template std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) { os << "{\n" << " \"name\": \"" << obj.name_ << "\",\n" << " \"problem\": " << obj.problem_ << ",\n" << " \"perf_result\": " << obj.perf_result_ << "\n" << "}"; return os; } std::ostream& operator<<(std::ostream& os, const PerformanceResult& result) { os << "{\n" << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_ << ",\n" << " \"tflops(TFlops)\": " << result.tflops_ << ",\n" << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n" << "}"; return os; } struct Settings { int n_warmup; int n_repeat; bool is_gpu_timer; int verify; int init_method; bool log; std::string csv_filename; bool flush_cache; int rotating_count; bool json_output; }; inline std::string get_rocm_version() { return std::to_string(HIP_VERSION_MAJOR) + "." + std::to_string(HIP_VERSION_MINOR); } template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { using ComputeType = std::conditional_t; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); // Calculate error due to split_k accumulation const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { using ComputeTypeAB = std::conditional_t; using ComputeType = std::conditional_t; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); // Calculate error due to split_k accumulation const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); }