mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 17:55:48 +00:00
[CK_TILE] Restructure Tile Engine's benchmarking and profiling (#4769) ## Motivation This PR introduces a restructure for the benchmarking and profiling aspects of CK Tile's Tile Engine, expanding on the groundwork from this previous https://github.com/ROCm/composable_kernel/pull/3434 and outlined in this [design document](https://amdcloud-my.sharepoint.com/:w:/r/personal/astharai_amd_com/Documents/Restructuring%20Tile%20Engine.docx?d=w14ea28a30718416988ed5ebb759bd3b2&csf=1&web=1&e=l3VBuX). In PR 3434, to reduce repeated code we implemented: - Base class that centralizes common functionality and provides a default implementation (Universal GEMM) - Child classes for GEMM variants override virtual functions to handle variant-specific behavior This refactoring in this PR follows the same process and should greatly reduce the duplicated code present in Tile Engine and make it simpler to add in new operations, increasing scalability. ## Technical Details The files have been refactored around new base structs for benchmarks, profiling and problem descriptions. The new base structs are: - GemmProblem - GemmBenchmark - GemmProfiler Universal GEMM, Preshuffle GEMM, and Multi-D GEMM all have child classes that will inherit from these base structs overriding only what differs per variant. All common functions across the benchmarking and profiling files have been moved into newly added common utility files under the commons/ directory. The new utility files are: - utils.hpp: common functions for the benchmarking and profiling process - benchmark_utils.py: common utility functions for the benchmark generation ## Test Plan I tested using the existing tests for Tile Engine. ## Test Result All tests passed. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
191 lines
7.4 KiB
C++
191 lines
7.4 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once

#include <algorithm>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <vector>

#include "ck_tile/host/device_prop.hpp"
#include "ck_tile/ops/gemm.hpp"

#include "gemm_benchmark.hpp"
|
template <typename Gemm, typename Problem, typename GemmArgs>
|
|
class GemmProfiler
|
|
{
|
|
public:
|
|
static Gemm& instance(Settings setting)
|
|
{
|
|
static Gemm instance{setting};
|
|
return instance;
|
|
}
|
|
|
|
// Overload for single kernel benchmarking
|
|
void benchmark(Problem& gemm_problem,
|
|
std::function<float(const GemmArgs&, const ck_tile::stream_config&)> kernel_func)
|
|
{
|
|
// Create a vector with a single callable that returns both name and time
|
|
std::vector<
|
|
std::function<std::tuple<std::string, float>(GemmArgs&, const ck_tile::stream_config&)>>
|
|
callables;
|
|
|
|
callables.push_back([kernel_func](GemmArgs& args, const ck_tile::stream_config& stream) {
|
|
float time = kernel_func(args, stream);
|
|
return std::make_tuple(std::string(KERNEL_NAME), time);
|
|
});
|
|
|
|
benchmark(gemm_problem, callables);
|
|
}
|
|
|
|
virtual void benchmark(Problem& gemm_problem,
|
|
std::vector<std::function<std::tuple<std::string, float>(
|
|
GemmArgs&, const ck_tile::stream_config&)>>& callables) = 0;
|
|
|
|
void process_result(const Problem& gemm_problem,
|
|
ck_tile::DeviceMem& c_m_n_dev_buf,
|
|
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
|
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
|
const std::tuple<std::string, float>& kernel_run_result)
|
|
{
|
|
auto [name, avg_time] = kernel_run_result;
|
|
using DDataType = typename get_DsDataType<Problem>::type;
|
|
|
|
KernelInstance<Problem> kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
|
|
|
|
// compute performance metric
|
|
std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
|
|
std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
|
|
sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
|
|
sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;
|
|
|
|
if constexpr(!std::is_void_v<DDataType>)
|
|
{
|
|
ck_tile::static_for<0, DDataType::size(), 1>{}([&](auto i) {
|
|
using DType = ck_tile::remove_cvref_t<std::tuple_element_t<i, DDataType>>;
|
|
num_byte += sizeof(DType) * gemm_problem.m_ * gemm_problem.n_;
|
|
flop += gemm_problem.m_ * gemm_problem.n_;
|
|
});
|
|
}
|
|
|
|
// update
|
|
kernel_instance.perf_result_.latency_ = avg_time;
|
|
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
|
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
|
|
|
if(setting_.log > 0 && !setting_.json_output)
|
|
{
|
|
std::cout << kernel_instance << std::endl;
|
|
}
|
|
|
|
// verify result
|
|
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
|
int split_k = 1;
|
|
if constexpr(std::is_same_v<Problem, GemmProblem>)
|
|
{
|
|
split_k = gemm_problem.split_k_;
|
|
}
|
|
bool verified_correct =
|
|
!setting_.verify ||
|
|
compare<Problem>(name, gemm_problem.k_, split_k, c_m_n_dev_result, c_m_n_host_result);
|
|
|
|
if(verified_correct)
|
|
{
|
|
kernel_instances_.emplace_back(kernel_instance);
|
|
}
|
|
else
|
|
{
|
|
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
|
}
|
|
|
|
// clear tensor
|
|
c_m_n_dev_buf.SetZero();
|
|
c_m_n_dev_result.SetZero();
|
|
}
|
|
|
|
KernelInstance<Problem> select_best_instance(Metric metric)
|
|
{
|
|
if(kernel_instances_.empty())
|
|
throw std::runtime_error("Empty instances");
|
|
|
|
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
|
|
kernel_instances_.end(),
|
|
[metric](const auto& a, const auto& b) {
|
|
return PerformanceResult::compare(
|
|
b.perf_result_, a.perf_result_, metric);
|
|
});
|
|
|
|
if(setting_.json_output)
|
|
{
|
|
// Output clean JSON only
|
|
std::cout << kernel_instance << std::endl;
|
|
}
|
|
else
|
|
{
|
|
std::cout << "**********************************" << std::endl;
|
|
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
|
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
|
std::cout << "**********************************" << std::endl;
|
|
}
|
|
|
|
if(!setting_.csv_filename.empty())
|
|
{
|
|
std::ofstream file(setting_.csv_filename + ".csv", std::ios::app);
|
|
|
|
if(!file.is_open())
|
|
{
|
|
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
|
|
}
|
|
else
|
|
{
|
|
if(file.tellp() == 0)
|
|
{
|
|
file << "rocm_version,device_name,"
|
|
<< "split_k,m,n,k,stride_a,stride_b,stride_c,"
|
|
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
|
|
<< "structured_sparsity," << "name,"
|
|
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
|
|
}
|
|
|
|
const auto& problem = kernel_instance.problem_;
|
|
const auto& name = kernel_instance.name_;
|
|
const auto& perf = kernel_instance.perf_result_;
|
|
|
|
file << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
|
<< problem.split_k_ << "," << problem.m_ << "," << problem.n_ << ","
|
|
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
|
|
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_
|
|
<< "," << problem.dtype_acc_ << "," << problem.dtype_c_ << ","
|
|
<< problem.layout_a_ << "," << problem.layout_b_ << "," << problem.layout_c_
|
|
<< "," << problem.structured_sparsity_ << "," << name << "," << std::fixed
|
|
<< std::setprecision(4) << perf.latency_ << "," << std::fixed
|
|
<< std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
|
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric)
|
|
<< "\n";
|
|
|
|
if(!file)
|
|
{
|
|
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
|
|
}
|
|
}
|
|
}
|
|
|
|
return kernel_instance;
|
|
}
|
|
|
|
GemmProfiler(const GemmProfiler&) = delete;
|
|
GemmProfiler& operator=(const GemmProfiler&) = delete;
|
|
|
|
protected:
|
|
virtual ~GemmProfiler() { kernel_instances_.clear(); }
|
|
GemmProfiler(Settings setting) : setting_(setting) {}
|
|
|
|
Settings setting_;
|
|
|
|
std::vector<KernelInstance<Problem>> kernel_instances_;
|
|
};
|