// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT #pragma once #include #include #include #include "ck_tile/host/device_prop.hpp" #include "ck_tile/ops/gemm.hpp" #include "gemm_streamk_benchmark.hpp" class GemmProfiler { public: static GemmProfiler& instance(Setting setting) { static GemmProfiler instance{setting}; return instance; } // Overload for single kernel benchmarking void benchmark(GemmProblem& gemm_problem, std::function( const ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)> kernel_func) { // Create a vector with a single callable that returns name, time, and num_wgs_per_tile std::vector( ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>> callables; callables.push_back( [kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) { auto [time, num_wgs_per_tile] = kernel_func(args, stream); return std::make_tuple(std::string(KERNEL_NAME), time, num_wgs_per_tile); }); benchmark(gemm_problem, callables); } void benchmark(GemmProblem& gemm_problem, std::vector( ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables) { const ALayout layout_a = ALayout{}; const BLayout layout_b = BLayout{}; const CLayout layout_c = CLayout{}; gemm_problem.stride_a_ = ck_tile::get_default_stride( gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a)); gemm_problem.stride_b_ = ck_tile::get_default_stride( gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b)); gemm_problem.stride_c_ = ck_tile::get_default_stride( gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c)); ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( gemm_problem.m_, gemm_problem.k_, gemm_problem.stride_a_, is_row_major(layout_a))); ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( gemm_problem.k_, gemm_problem.n_, gemm_problem.stride_b_, is_row_major(layout_b))); ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); if(setting_.init_method_ == 0) { ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n); } else if(setting_.init_method_ == 1) { ck_tile::FillMonotonicSeq{}(a_m_k); ck_tile::FillMonotonicSeq{}(b_k_n); } else if(setting_.init_method_ == 2) { ck_tile::FillConstant{static_cast(1)}(a_m_k); ck_tile::FillConstant{static_cast(1)}(b_k_n); } else { a_m_k.SetZero(); b_k_n.SetZero(); } if(gemm_problem.structured_sparsity_) { ck_tile::AdjustToStructuredSparsity{}(a_m_k); } ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); if constexpr(std::is_same_v) { // Permute vector pk_i4x4 data for device implementation ck_tile::HostTensor b_k_n_dev = b_k_n; // permute_tensor_b(b_k_n_dev); permute_vectors_i4x4_b(b_k_n_dev); b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); } else { b_k_n_dev_buf.ToDevice(b_k_n.data()); } a_m_k_dev_buf.ToDevice(a_m_k.data()); c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); ck_tile::StreamKHostArgs gemm_args{a_m_k_dev_buf.GetDeviceBuffer(), b_k_n_dev_buf.GetDeviceBuffer(), c_m_n_dev_buf.GetDeviceBuffer(), gemm_problem.m_, gemm_problem.n_, gemm_problem.k_, gemm_problem.stride_a_, gemm_problem.stride_b_, gemm_problem.stride_c_}; ck_tile::HostTensor c_m_n_host_result(ck_tile::host_tensor_descriptor( gemm_problem.m_, gemm_problem.n_, gemm_problem.stride_c_, is_row_major(layout_c))); if(setting_.verify_) { gemm_host_reference(setting_.verify_, a_m_k, b_k_n, c_m_n_host_result, a_m_k_dev_buf, b_k_n_dev_buf, gemm_problem.m_, gemm_problem.n_, gemm_problem.k_, gemm_problem.stride_a_, gemm_problem.stride_b_, gemm_problem.stride_c_); } for(auto& callable : callables) { auto kernel_run_result = callable(gemm_args, ck_tile::stream_config{nullptr, true, setting_.log_, setting_.n_warmup_, setting_.n_repeat_, setting_.is_gpu_timer_, setting_.flush_cache_, setting_.rotating_count_}); process_result(gemm_problem, c_m_n_dev_buf, c_m_n_host_result, c_m_n_dev_result, kernel_run_result); } } void process_result(const GemmProblem& gemm_problem, ck_tile::DeviceMem& c_m_n_dev_buf, ck_tile::HostTensor& c_m_n_host_result, ck_tile::HostTensor& c_m_n_dev_result, const std::tuple& kernel_run_result) { auto [name, avg_time, num_wgs_per_tile] = kernel_run_result; auto dp_persistent = SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel"; auto reduction_strategy = SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Atomic ? "Atomic" : SelectedKernel::reduction_strategy == ck_tile::StreamKReductionStrategy::Linear ? "Linear" : "Tree"; KernelInstance kernel_instance{ name, dp_persistent, reduction_strategy, gemm_problem, {-1.0f, -1.0f, -1.0f}}; // compute performance metric std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_; std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ + sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ + sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_; // update kernel_instance.perf_result_.latency_ = avg_time; kernel_instance.perf_result_.tflops_ = static_cast(flop) / 1.E9 / avg_time; kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time; if(setting_.log_ > 0 && !setting_.json_output_) { std::cout << kernel_instance << std::endl; } // verify result c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool verified_correct = !setting_.verify_ || compare(name, gemm_problem.k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_host_result); if(verified_correct) { kernel_instances_.emplace_back(kernel_instance); } else { std::cout << "Verification failed, skip kernel: " << name << std::endl; } // clear tensor c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); } KernelInstance select_best_instance(Metric metric) { if(kernel_instances_.empty()) throw std::runtime_error("Empty instances"); auto kernel_instance = *std::max_element(kernel_instances_.begin(), kernel_instances_.end(), [metric](const auto& a, const auto& b) { return PerformanceResult::compare( b.perf_result_, a.perf_result_, metric); }); if(setting_.json_output_) { // Output clean JSON only std::cout << kernel_instance << std::endl; } else { std::cout << "**********************************" << std::endl; std::cout << "According to given metrics: " << get_metric_name(metric) << "\n" << "Current kernel performance is: " << kernel_instance << std::endl; std::cout << "**********************************" << std::endl; } if(!setting_.csv_filename_.empty()) { std::ofstream file(setting_.csv_filename_ + ".csv", std::ios::app); if(!file.is_open()) { std::cerr << "Warning: Failed to open CSV file for writing." << std::endl; } else { if(file.tellp() == 0) { file << "rocm_version,device_name," << "split_k,m,n,k,stride_a,stride_b,stride_c," << "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c," << "structured_sparsity," << "dp_persistent," << "reduction_strategy," << "name," << "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n"; } const auto& problem = kernel_instance.problem_; const auto& name = kernel_instance.name_; const auto& dp_persistent = kernel_instance.dp_persistent_; const auto& reduction_strategy = kernel_instance.reduction_strategy_; const auto& perf = kernel_instance.perf_result_; file << get_rocm_version() << "," << ck_tile::get_device_name() << "," << problem.split_k_ << "," << problem.m_ << "," << problem.n_ << "," << problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << "," << problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_ << "," << problem.dtype_acc_ << "," << problem.dtype_c_ << "," << problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_ << "," << problem.structured_sparsity_ << "," << dp_persistent << "," << reduction_strategy << "," << name << "," << std::fixed << std::setprecision(4) << perf.latency_ << "," << std::fixed << std::setprecision(4) << perf.tflops_ << "," << std::fixed << std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric) << "\n"; if(!file) { std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl; } } } return kernel_instance; } GemmProfiler(const GemmProfiler&) = delete; GemmProfiler& operator=(const GemmProfiler&) = delete; private: ~GemmProfiler() { kernel_instances_.clear(); } GemmProfiler(Setting setting) : setting_(setting) {} Setting setting_; std::vector kernel_instances_; };