#pragma once #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "gemm_preshuffle_common.hpp" //[TODO] Move parts of this File to commons enum class Metric { LATENCY = 0, TFLOPS = 1, BANDWIDTH = 2 }; inline constexpr auto get_metric_name(Metric m) { switch(m) { case Metric::LATENCY: return "latency"; case Metric::TFLOPS: return "tflops"; case Metric::BANDWIDTH: return "bandwidth"; default: throw std::invalid_argument("Unsupported metric type"); } } struct KernelConfig { std::tuple tile_dims; std::tuple warp_dims; std::tuple warp_tile_dims; bool permuteN; }; struct GemmProblem { int split_k_; int m_, n_, k_; int stride_a_, stride_b_, stride_c_; std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_; std::string layout_a_, layout_b_, layout_c_; bool structured_sparsity_; friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem) { os << "{\n" << " \"split_k\":" << problem.split_k_ << ",\n" << " \"m\":" << problem.m_ << ",\n" << " \"n\":" << problem.n_ << ",\n" << " \"k\":" << problem.k_ << ",\n" << " \"stride_a\":" << problem.stride_a_ << ",\n" << " \"stride_b\":" << problem.stride_b_ << ",\n" << " \"stride_c\":" << problem.stride_c_ << ",\n" << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n" << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n" << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n" << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n" << " \"layout_a\":\"" << problem.layout_a_ << "\",\n" << " \"layout_b\":\"" << problem.layout_b_ << "\",\n" << " \"layout_c\":\"" << problem.layout_c_ << "\",\n" << " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false") << "\n" << "}"; return os; } }; struct PerformanceResult { double latency_; double tflops_; double bandwidth_; static bool compare(const PerformanceResult& a, const PerformanceResult& b, Metric m) { switch(m) { case Metric::LATENCY: return a.latency_ < b.latency_; case Metric::TFLOPS: return a.tflops_ > b.tflops_; case Metric::BANDWIDTH: return a.bandwidth_ > b.bandwidth_; default: throw std::invalid_argument("Unsupported metric type"); } } friend std::ostream& operator<<(std::ostream& os, const PerformanceResult& result) { os << "{\n" << " \"latency(ms)\": " << std::fixed << std::setprecision(2) << result.latency_ << ",\n" << " \"tflops(TFlops)\": " << result.tflops_ << ",\n" << " \"bandwidth(GB/s)\": " << result.bandwidth_ << "\n" << "}"; return os; } }; struct KernelInstance { std::string name_; GemmProblem problem_; PerformanceResult perf_result_; static bool compare(const KernelInstance& a, const KernelInstance& b, Metric m) { return PerformanceResult::compare(a.perf_result_, b.perf_result_, m); } friend std::ostream& operator<<(std::ostream& os, const KernelInstance& obj) { os << "{\n" << " \"name\": \"" << obj.name_ << "\",\n" << " \"problem\": " << obj.problem_ << ",\n" << " \"perf_result\": " << obj.perf_result_ << "\n" << "}"; return os; } }; struct Setting { int n_warmup_; int n_repeat_; bool is_gpu_timer_; int verify_; int init_method_; bool log_; std::string csv_filename_; bool flush_cache_; int rotating_count_; bool json_output_; }; inline std::string get_rocm_version() { std::ifstream version_file("/opt/rocm/.info/version"); if(version_file.is_open()) { std::string version; std::getline(version_file, version); return version; } return "Unknown"; } template auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) { using ComputeType = std::conditional_t; // Calculate thresholds const auto rtol = ck_tile::get_relative_threshold( ck_tile::integer_divide_ceil(K, kbatch)); const auto atol = ck_tile::get_absolute_threshold( max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); // Calculate error due to split_k accumulation const auto rtol_split_k = ck_tile::get_relative_threshold(kbatch); const auto atol_split_k = ck_tile::get_absolute_threshold( max_accumulated_value, kbatch); // Use higher threshold return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } /// @brief Function to compare the results of the device and host computations bool compare(std::string instanceName, ck_tile::index_t K, ck_tile::index_t kbatch, ck_tile::HostTensor& c_m_n_dev_result, ck_tile::HostTensor& c_m_n_ref) { const float max_accumulated_value = *std::max_element(c_m_n_ref.mData.begin(), c_m_n_ref.mData.end()); const auto rtol_atol = calculate_rtol_atol( K, kbatch, max_accumulated_value); bool pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), rtol_atol.at(ck_tile::number<1>{})); std::cout << "For " << instanceName << " Relative error threshold is " << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is " << rtol_atol.at(ck_tile::number<1>{}) << std::endl; std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl; return pass; } /// @brief Function to get the kernel output with reference implementation on CPU/GPU void gemm_host_reference(int verify, ck_tile::HostTensor& a_m_k, ck_tile::HostTensor& b_k_n, ck_tile::HostTensor& c_m_n_ref, ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t stride_A, ck_tile::index_t stride_B, ck_tile::index_t stride_C) { if(verify == 1) { ck_tile::reference_gemm( a_m_k, b_k_n, c_m_n_ref); } else if(verify == 2) { a_m_k_dev_buf.ToDevice(a_m_k.data()); b_k_n_dev_buf.ToDevice(b_k_n.data()); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_ref.get_element_space_size_in_bytes()); c_m_n_gpu_buf_ref.SetZero(); ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); c_m_n_gpu_buf_ref.FromDevice(c_m_n_ref.data()); } }