composable_kernel/tile_engine/ops/gemm/gemm_benchmark.hpp
Thrupti Raj Lakshmana Gowda d7609923b6 [rocm-libraries] ROCm/rocm-libraries#7919 (commit 061001d)
Users/tlakshma/ck/tile engine develop

## Motivation

This PR adds multiple new GPU kernel benchmarking operations to the CK
Tile Engine, expanding its coverage of GEMM-family operations:

- **gemm_multi_abd**: GEMM with multiple A, B, and D tensors, enabling
epilogue patterns such as scale/bias fusion.
- **batched_contraction**: Batched tensor contraction supporting
multi-dimensional batch (G), M, N, and K dimensions, targeting workloads
where the contraction indices span more than one logical axis.
- **mx_gemm**: MX-format GEMM with microscaling (e8m0) scale tensors.
- **gemm_rowcolquant**: Block-scale GEMM with row/column quantization.
- **gemm_tensor_quant**: Block-scale GEMM with tensor quantization.
- **grouped_gemm_rowcolquant**: Grouped GEMM with row/column
quantization.
- **grouped_gemm_tensorquant**: Grouped GEMM with tensor quantization.
- **batched_gemm**: Batched GEMM benchmarking support.

## Technical Details

### gemm_multi_abd

- New subdirectory: tile_engine/ops/gemm/gemm_multi_abd/
- CMakeLists.txt follows the same individual-target pattern as
  gemm_universal / gemm_multi_d.
- gemm_multi_abd_instance_builder.py subclasses GemmKernelBuilder from
  the shared gemm_instance_builder.py.
- gemm_multi_abd_benchmark.py delegates to the shared GemmBenchmark
  parent class.
- Configs: default_config.json, default_ci_config.json,
  user_provided_config.json.
- Supported GPU targets: gfx90a, gfx942, gfx950, gfx1201.
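
To make the scale/bias fusion concrete, here is a minimal host-reference sketch in C++. It is not the CK Tile API: the function name multi_abd_reference, the row-major layouts, the elementwise product used to combine the two A inputs, and the reading of D0 as a per-column scale and D1 as a per-column bias are all illustrative assumptions. A single B input is used for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host reference for a GEMM with two A inputs and two D inputs.
// Assumptions (illustration only, not the CK Tile API):
//   * all tensors are row-major: A0/A1 are M x K, B is K x N, E is M x N;
//   * the two A inputs are combined by elementwise multiplication;
//   * D0 is a per-column scale and D1 a per-column bias (length N each).
std::vector<float> multi_abd_reference(const std::vector<float>& a0,
                                       const std::vector<float>& a1,
                                       const std::vector<float>& b,
                                       const std::vector<float>& d0,
                                       const std::vector<float>& d1,
                                       std::size_t M,
                                       std::size_t N,
                                       std::size_t K)
{
    assert(a0.size() == M * K && a1.size() == M * K && b.size() == K * N);
    assert(d0.size() == N && d1.size() == N);

    std::vector<float> e(M * N);
    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
            {
                // Combine the A inputs elementwise, then accumulate A * B.
                acc += a0[m * K + k] * a1[m * K + k] * b[k * N + n];
            }
            // Fused epilogue: scale by D0, then add bias D1.
            e[m * N + n] = acc * d0[n] + d1[n];
        }
    }
    return e;
}
```

Fusing the D tensors into the epilogue avoids writing the raw GEMM output to memory and reading it back for a separate scale/bias pass.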

### batched_contraction

- New subdirectory: tile_engine/ops/gemm/batched_contraction/
- BatchedContractionKernelBuilder extends GemmKernelBuilder, adding the
  num_dim_g, num_dim_m, num_dim_n, num_dim_k, num_d_tensors, and
  elementwise_function parameters.
- The layout string uses a 3-character encoding, one character per tensor
  in the order A, B, E (r = row-major, c = column-major), e.g. rcr.
- Self-contained benchmark sweep driver
  (batched_contraction_benchmark.py) with JSON/CSV export and best-kernel
  selection.
- Supported GPU targets: gfx90a, gfx942, gfx950.
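
The layout encoding and the multi-dimensional axes can be illustrated with a small, hypothetical C++ helper (parse_layout and flatten_extent are not part of the tile engine). It assumes 'r' means row-major and 'c' column-major, and that each group of logical axes is packed contiguously. Under that assumption a contraction reduces to a batched GEMM whose G, M, N and K are the products of the corresponding axis extents.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: decode a 3-character layout string (A, B, E order),
// assuming 'r' = row-major and 'c' = column-major.
struct ContractionLayout
{
    bool a_row_major;
    bool b_row_major;
    bool e_row_major;
};

ContractionLayout parse_layout(const std::string& layout)
{
    if(layout.size() != 3)
        throw std::invalid_argument("layout must have exactly 3 characters (A, B, E)");
    auto is_row_major = [](char c) {
        if(c == 'r')
            return true;
        if(c == 'c')
            return false;
        throw std::invalid_argument("layout characters must be 'r' or 'c'");
    };
    return {is_row_major(layout[0]), is_row_major(layout[1]), is_row_major(layout[2])};
}

// Collapse the extents of one group of logical axes (e.g. all M axes) into a
// single GEMM dimension. Valid only when those axes are packed contiguously.
std::int64_t flatten_extent(const std::vector<std::int64_t>& dims)
{
    std::int64_t total = 1;
    for(std::int64_t d : dims)
    {
        assert(d > 0); // every logical axis must be non-empty
        total *= d;
    }
    return total;
}
```

For example, M axes of extent {4, 8} and K axes of extent {3, 32} flatten to M = 32 and K = 96.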

### mx_gemm

- New subdirectory: tile_engine/ops/gemm/mx_gemm/
- Supports MX-format microscaling, with e8m0 scale tensors for A and B.
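
As a sketch of what the e8m0 scales mean numerically: in the OCP Microscaling (MX) specification an E8M0 value is a bare 8-bit exponent with bias 127 (0xFF encodes NaN), and each block of elements (32 in the spec) shares one scale. The decoder below follows that definition. mx_block_dot is a hypothetical host reference that assumes the A and B elements have already been converted to float; real MX elements are narrow formats such as FP8, FP6 or FP4.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>

// Decode an E8M0 scale as defined by the OCP MX spec:
// 8 exponent bits, no sign, no mantissa, bias 127; 0xFF encodes NaN.
float decode_e8m0(std::uint8_t bits)
{
    if(bits == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(bits) - 127); // 2^(bits - 127)
}

// Hypothetical host reference for one K-block of an MX GEMM: each operand
// carries one shared E8M0 scale for the whole block.
float mx_block_dot(const float* a,
                   const float* b,
                   std::size_t block_size,
                   std::uint8_t scale_a,
                   std::uint8_t scale_b)
{
    assert(a != nullptr && b != nullptr);
    float acc = 0.0f;
    for(std::size_t i = 0; i < block_size; ++i)
        acc += a[i] * b[i];
    // Scales are powers of two, so applying them once per block is exact.
    return acc * decode_e8m0(scale_a) * decode_e8m0(scale_b);
}
```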

### block_scale_gemm (gemm_rowcolquant, gemm_tensor_quant)

- New subdirectory: tile_engine/ops/gemm/block_scale_gemm/
- gemm_rowcolquant: row/column quantization epilogue.
- gemm_tensor_quant: tensor-level quantization epilogue.
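
The difference between the two epilogues can be sketched as host-side dequantization of the accumulator. The function names are hypothetical, and the sketch assumes that row/column quantization carries one scale per row of A and one per column of B, while tensor quantization carries a single scale per input tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host references for the two block-scale epilogues. `acc` is the
// M x N row-major accumulator of the quantized A * B product.

// Row/column quantization: one scale per row of A (length M) and one per
// column of B (length N).
std::vector<float> dequant_rowcol(const std::vector<float>& acc,
                                  const std::vector<float>& scale_a,
                                  const std::vector<float>& scale_b,
                                  std::size_t M,
                                  std::size_t N)
{
    assert(acc.size() == M * N && scale_a.size() == M && scale_b.size() == N);
    std::vector<float> out(M * N);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            out[m * N + n] = acc[m * N + n] * scale_a[m] * scale_b[n];
    return out;
}

// Tensor quantization: a single scalar scale for all of A and one for all of B.
std::vector<float> dequant_tensor(const std::vector<float>& acc, float scale_a, float scale_b)
{
    std::vector<float> out(acc.size());
    for(std::size_t i = 0; i < acc.size(); ++i)
        out[i] = acc[i] * scale_a * scale_b;
    return out;
}
```

Tensor quantization is the cheaper epilogue (one multiply per element), while row/column quantization keeps more precision when rows of A or columns of B have very different dynamic ranges.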

### grouped_gemm_quant (grouped_gemm_rowcolquant, grouped_gemm_tensorquant)

- New subdirectory: tile_engine/ops/gemm/grouped_gemm_quant/
- grouped_gemm_rowcolquant: grouped GEMM with row/column quantization.
- grouped_gemm_tensorquant: grouped GEMM with tensor quantization.
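
Because each group in a grouped GEMM is an independent problem with its own M, N and K, throughput has to be summed per group rather than computed from a single shape. The sketch below shows that bookkeeping under the standard 2 * M * N * K operation count. GroupShape, grouped_gemm_flops and tflops are hypothetical names, not the benchmark's actual API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical per-group problem descriptor.
struct GroupShape
{
    std::int64_t M, N, K;
};

// Total floating-point operations of a grouped GEMM: each group contributes
// 2 * M * N * K (one multiply and one add per inner-product term).
std::int64_t grouped_gemm_flops(const std::vector<GroupShape>& groups)
{
    std::int64_t flops = 0;
    for(const GroupShape& g : groups)
    {
        assert(g.M > 0 && g.N > 0 && g.K > 0);
        flops += 2 * g.M * g.N * g.K;
    }
    return flops;
}

// Convert an operation count and a kernel time in milliseconds to TFLOPS:
// flops / (time_ms * 1e-3 s) / 1e12 = flops / (time_ms * 1e9).
double tflops(std::int64_t flops, double time_ms)
{
    assert(time_ms > 0.0);
    return static_cast<double>(flops) / (time_ms * 1e9);
}
```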

### batched_gemm

- New subdirectory: tile_engine/ops/gemm/batched_gemm/
- Batched GEMM benchmark support is wired into the sampling/active-op
  lists.
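
For reference, a strided batched GEMM runs the same M x N x K problem for every batch index, with each operand's batch located at a fixed element stride. The host sketch below assumes row-major operands and is illustrative only; batched_gemm_reference is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical host reference for a strided batched GEMM: for each batch index
// g, C_g = A_g * B_g, where A_g is M x K, B_g is K x N and C_g is M x N, all
// row-major, and batch g starts at element offset g * batch_stride_*.
void batched_gemm_reference(const float* a,
                            const float* b,
                            float* c,
                            std::size_t batch,
                            std::size_t M,
                            std::size_t N,
                            std::size_t K,
                            std::size_t batch_stride_a,
                            std::size_t batch_stride_b,
                            std::size_t batch_stride_c)
{
    assert(a != nullptr && b != nullptr && c != nullptr);
    for(std::size_t g = 0; g < batch; ++g)
    {
        const float* a_g = a + g * batch_stride_a;
        const float* b_g = b + g * batch_stride_b;
        float* c_g       = c + g * batch_stride_c;
        for(std::size_t m = 0; m < M; ++m)
        {
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.0f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a_g[m * K + k] * b_g[k * N + n];
                c_g[m * N + n] = acc;
            }
        }
    }
}
```

For densely packed batches the strides are simply M * K, K * N and M * N.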

All new ops are registered in op_weights.json for budget allocation and
wired into the active-op sampling lists in CMakeLists.txt.
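
The description does not spell out how op_weights.json turns weights into budgets. One plausible scheme, shown here purely as a hypothetical sketch and not as the tile engine's logic, is a proportional split using the largest-remainder method, so the per-op shares always sum to the total budget.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical: split a total benchmark budget across ops in proportion to
// their weights, using the largest-remainder method so the shares sum to the
// budget exactly. Not the tile engine's actual allocation logic.
std::map<std::string, std::size_t> allocate_budget(const std::map<std::string, double>& weights,
                                                   std::size_t budget)
{
    double total = 0.0;
    for(const auto& entry : weights)
    {
        assert(entry.second >= 0.0);
        total += entry.second;
    }
    assert(total > 0.0);

    std::map<std::string, std::size_t> shares;
    std::vector<std::pair<double, std::string>> remainders;
    std::size_t assigned = 0;
    for(const auto& entry : weights)
    {
        const double exact  = static_cast<double>(budget) * entry.second / total;
        const auto whole    = static_cast<std::size_t>(exact); // floor for non-negative values
        shares[entry.first] = whole;
        assigned += whole;
        remainders.emplace_back(exact - static_cast<double>(whole), entry.first);
    }

    // Hand the leftover units to the ops with the largest fractional parts.
    std::sort(remainders.begin(), remainders.end(), [](const auto& x, const auto& y) {
        return x.first > y.first;
    });
    for(std::size_t i = 0; i < remainders.size() && assigned < budget; ++i, ++assigned)
        ++shares[remainders[i].second];
    return shares;
}
```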

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-11 20:38:38 +00:00


// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <string>
#include <fstream>
#include <stdexcept>
#include <iomanip>
#include <algorithm>   // std::max_element
#include <type_traits> // std::void_t, std::is_void_v
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "common/utils.hpp"
// Data types and Layouts are defined by the generated kernel headers
// No hardcoded type definitions here to avoid conflicts
struct GemmProblem
{
    int split_k_;
    int m_, n_, k_;
    int stride_a_, stride_b_, stride_c_;
    std::string dtype_a_, dtype_b_, dtype_acc_, dtype_c_;
    std::string layout_a_, layout_b_, layout_c_;
    bool structured_sparsity_;

    friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os,
                                    const GemmProblem& problem)
    {
        os << "{\n"
           << " \"split_k\":" << problem.split_k_ << ",\n"
           << " \"m\":" << problem.m_ << ",\n"
           << " \"n\":" << problem.n_ << ",\n"
           << " \"k\":" << problem.k_ << ",\n"
           << " \"stride_a\":" << problem.stride_a_ << ",\n"
           << " \"stride_b\":" << problem.stride_b_ << ",\n"
           << " \"stride_c\":" << problem.stride_c_ << ",\n"
           << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n"
           << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n"
           << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n"
           << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n"
           << " \"layout_a\":\"" << problem.layout_a_ << "\",\n"
           << " \"layout_b\":\"" << problem.layout_b_ << "\",\n"
           << " \"layout_c\":\"" << problem.layout_c_ << "\",\n"
           << " \"structured_sparsity\":" << (problem.structured_sparsity_ ? "true" : "false")
           << "\n"
           << "}";
        return os;
    }
};
// Detect Problem::DsDataType, default to void when absent
template <class T, class = void>
struct get_DsDataType
{
    using type = void;
};

template <class T>
struct get_DsDataType<T, std::void_t<typename T::DsDataType>>
{
    using type = typename T::DsDataType;
};

// Detect Problem::D0DataType, default to void when absent
template <class T, class = void>
struct get_D0DataType
{
    using type = void;
};

template <class T>
struct get_D0DataType<T, std::void_t<typename T::D0DataType>>
{
    using type = typename T::D0DataType;
};

/// @brief Generic compare: all types provided explicitly as template parameters
template <typename AType, typename BType, typename AccType, typename OutType, typename DType = void>
bool compare(const std::string& instanceName,
             ck_tile::index_t K,
             ck_tile::index_t kbatch,
             ck_tile::HostTensor<OutType>& out_dev_result,
             ck_tile::HostTensor<OutType>& out_host_result)
{
    // The largest reference value feeds into the tolerance calculation.
    const float max_accumulated_value =
        *std::max_element(out_host_result.mData.begin(), out_host_result.mData.end());

    // When a D tensor is present, its data type also enters the tolerance calculation.
    auto rtol_atol = [&] {
        if constexpr(std::is_void_v<DType>)
        {
            return calculate_rtol_atol<AType, BType, AccType, OutType>(
                K, kbatch, max_accumulated_value);
        }
        else
        {
            return calculate_rtol_atol<AType, BType, DType, AccType, OutType>(
                K, kbatch, max_accumulated_value);
        }
    }();

    bool pass = ck_tile::check_err(out_dev_result,
                                   out_host_result,
                                   "Error: Incorrect results!",
                                   rtol_atol.at(ck_tile::number<0>{}),
                                   rtol_atol.at(ck_tile::number<1>{}));
    std::cout << "For " << instanceName << " Relative error threshold is "
              << rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
              << rtol_atol.at(ck_tile::number<1>{}) << std::endl;
    std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
    return pass;
}

/// @brief Backward-compatible compare: takes A, B, and Acc types from the global ADataType,
/// BDataType, and AccDataType aliases (defined by the generated kernel header) and the D type
/// from Problem::D0DataType (void when absent)
template <typename Problem, typename OutType>
bool compare(const std::string& instanceName,
             ck_tile::index_t K,
             ck_tile::index_t kbatch,
             ck_tile::HostTensor<OutType>& out_dev_result,
             ck_tile::HostTensor<OutType>& out_host_result)
{
    using DDataType = typename get_D0DataType<Problem>::type;
    return compare<ADataType, BDataType, AccDataType, OutType, DDataType>(
        instanceName, K, kbatch, out_dev_result, out_host_result);
}
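
The backward-compatible compare relies on get_D0DataType to decide whether a D data type enters the tolerance calculation. The detection pattern can be checked in isolation: the trait is reproduced below so the example compiles on its own, and PlainProblem, BiasProblem and uses_d_tolerance are hypothetical stand-ins for generated kernel types.

```cpp
#include <cassert>
#include <type_traits>

// Same detection pattern as get_D0DataType above, reproduced for a standalone build.
template <class T, class = void>
struct get_D0DataType
{
    using type = void;
};

template <class T>
struct get_D0DataType<T, std::void_t<typename T::D0DataType>>
{
    using type = typename T::D0DataType;
};

// Hypothetical problem descriptors: one without a D0 tensor, one with a float D0.
struct PlainProblem
{
};
struct BiasProblem
{
    using D0DataType = float;
};

static_assert(std::is_void_v<typename get_D0DataType<PlainProblem>::type>);
static_assert(std::is_same_v<typename get_D0DataType<BiasProblem>::type, float>);

// Hypothetical helper mirroring the branch compare() takes: true when a D type is detected.
template <class Problem>
constexpr bool uses_d_tolerance()
{
    return !std::is_void_v<typename get_D0DataType<Problem>::type>;
}
```

The primary template is selected whenever Problem::D0DataType does not exist, because the partial specialization's std::void_t argument fails to form and is silently discarded (SFINAE).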