Files
composable_kernel/tile_engine/ops/gemm/gemm_benchmark.hpp
arai713 5d2fce819d [rocm-libraries] ROCm/rocm-libraries#4769 (commit 72ae66e)
[CK_TILE] Restructure Tile Engine's benchmarking and
 profiling (#4769)

## Motivation
This PR introduces a restructure for the benchmarking and profiling
aspects of CK Tile's Tile Engine, expanding on the groundwork from this
previous https://github.com/ROCm/composable_kernel/pull/3434 and
outlined in this [design
document](https://amdcloud-my.sharepoint.com/:w:/r/personal/astharai_amd_com/Documents/Restructuring%20Tile%20Engine.docx?d=w14ea28a30718416988ed5ebb759bd3b2&csf=1&web=1&e=l3VBuX).
In PR 3434, to reduce repeated code we implemented:

- Base class that centralizes common functionality and provides a
default implementation (Universal GEMM)
- Child classes for GEMM variants override virtual functions to handle
variant-specific behavior

This refactoring in this PR follows the same process and should greatly
reduce the duplicated code present in Tile Engine and make it simpler to
add in new operations, increasing scalability.

## Technical Details
The files have been refactored around new base structs for benchmarks,
profiling and problem descriptions. The new base structs are:

- GemmProblem
- GemmBenchmark
- GemmProfiler

Universal GEMM, Preshuffle GEMM, and Multi-D GEMM all have child classes
that will inherit from these base structs overriding only what differs
per variant.
All common functions across the benchmarking and profiling files have
been moved into newly added common utility files under the commons/
directory. The new utility files are:

- utils.hpp: common functions for the benchmarking and profiling process
- benchmark_utils.py: common utility functions for the benchmark
generation

## Test Plan
I tested using the existing tests for Tile Engine.
## Test Result
All tests passed.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-14 17:51:20 +00:00

117 lines
4.0 KiB
C++

// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include <iomanip>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "common/utils.hpp"
// Data types and Layouts are defined by the generated kernel headers
// No hardcoded type definitions here to avoid conflicts
struct GemmProblem
{
int split_k_;
int m_, n_, k_;
int stride_a_, stride_b_, stride_c_;
std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
std::string layout_a_, layout_b_, layout_c_;
bool structured_sparsity_;
friend std::ostream& operator<<(std::ostream& os, const GemmProblem& problem)
{
os << "{\n"
<< " \"split_k\":" << problem.split_k_ << ",\n"
<< " \"m\":" << problem.m_ << ",\n"
<< " \"n\":" << problem.n_ << ",\n"
<< " \"k\":" << problem.k_ << ",\n"
<< " \"stride_a\":" << problem.stride_a_ << ",\n"
<< " \"stride_b\":" << problem.stride_b_ << ",\n"
<< " \"stride_c\":" << problem.stride_c_ << ",\n"
<< " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
<< " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
<< " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
<< " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
<< " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
<< " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
<< " \"layout_c\":\"" << problem.layout_c_ << "\",\n"
<< " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false")
<< "\n"
<< "}";
return os;
}
};
// Detect Problem::DsDataType, default to void when absent
template <class T, class = void>
struct get_DsDataType
{
using type = void;
};
template <class T>
struct get_DsDataType<T, std::void_t<typename T::DsDataType>>
{
using type = typename T::DsDataType;
};
// Detect Problem::D0DataType, default to void when absent
template <class T, class = void>
struct get_D0DataType
{
using type = void;
};
template <class T>
struct get_D0DataType<T, std::void_t<typename T::D0DataType>>
{
using type = typename T::D0DataType;
};
/// @brief Function to compare the results of the device and host computations
template <typename Problem>
bool compare(std::string instanceName,
ck_tile::index_t K,
ck_tile::index_t kbatch,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
ck_tile::HostTensor<CDataType>& c_m_n_host_result)
{
using DDataType = typename get_D0DataType<Problem>::type;
const float max_accumulated_value =
*std::max_element(c_m_n_host_result.mData.begin(), c_m_n_host_result.mData.end());
// const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
// K, kbatch, max_accumulated_value);
auto rtol_atol = [&] {
if constexpr(std::is_void_v<DDataType>)
{
return calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
}
else
{
return calculate_rtol_atol<ADataType, BDataType, DDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
}
}();
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_result,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
std::cout << "For " << instanceName << " Relative error threshold is "
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
return pass;
}