composable_kernel/tile_engine/ops/gemm/gemm_common.hpp
Thrupti Raj Lakshmana Gowda d7609923b6 [rocm-libraries] ROCm/rocm-libraries#7919 (commit 061001d)
Users/tlakshma/ck/tile engine develop

## Motivation

This PR adds multiple new GPU kernel benchmarking operations to the CK
Tile Engine, expanding its coverage of GEMM-family operations:

- **gemm_multi_abd**: GEMM with multiple A, B, and D tensors, enabling
epilogue patterns such as scale/bias fusion.
- **batched_contraction**: Batched tensor contraction supporting
multi-dimensional batch (G), M, N, and K dimensions, targeting workloads
where the contraction indices span more than one logical axis.
- **mx_gemm**: MX-format GEMM with microscaling (e8m0) scale tensors.
- **gemm_rowcolquant**: Block-scale GEMM with row/column quantization.
- **gemm_tensor_quant**: Block-scale GEMM with tensor quantization.
- **grouped_gemm_rowcolquant**: Grouped GEMM with row/column
quantization.
- **grouped_gemm_tensorquant**: Grouped GEMM with tensor quantization.
- **batched_gemm**: Batched GEMM benchmarking support.

## Technical Details

### gemm_multi_abd

- New subdirectory: tile_engine/ops/gemm/gemm_multi_abd/
- CMakeLists.txt follows the same individual-target pattern as
  gemm_universal / gemm_multi_d.
- gemm_multi_abd_instance_builder.py subclasses GemmKernelBuilder from
  the shared gemm_instance_builder.py.
- gemm_multi_abd_benchmark.py delegates to the shared GemmBenchmark
  parent class.
- Configs: default_config.json, default_ci_config.json,
  user_provided_config.json.
- Supported GPU targets: gfx90a, gfx942, gfx950, gfx1201.
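This section describes structure rather than semantics. As a mental model only, a multi-ABD GEMM can be read in four steps:

1. Combine several A tensors elementwise.
2. Combine several B tensors elementwise.
3. Run an ordinary GEMM.
4. Fold the D tensors into the result in the epilogue (the scale/bias fusion mentioned in the motivation).

The CPU reference below is a sketch under assumptions not taken from this PR:

- All storage is dense and row-major.
- The As and Bs are each combined by elementwise product.
- The epilogue has the form E = (A·B) * D0 + D1.

The names `Matrix` and `reference_multi_abd` are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<float>; // dense, row-major (hypothetical alias)

// Hypothetical CPU reference for a multi-ABD GEMM (illustration only).
// As: each M x K; Bs: each K x N; d_scale, d_bias: each M x N.
// Assumed semantics: A = elementwise product of As, B = elementwise
// product of Bs, E = (A * B) * d_scale + d_bias.
Matrix reference_multi_abd(const std::vector<Matrix>& As,
                           const std::vector<Matrix>& Bs,
                           const Matrix& d_scale,
                           const Matrix& d_bias,
                           std::size_t M, std::size_t N, std::size_t K)
{
    assert(!As.empty() && !Bs.empty());
    assert(d_scale.size() == M * N && d_bias.size() == M * N);

    // Combine the A tensors elementwise (assumed: product).
    Matrix a(M * K, 1.0f);
    for(const auto& ai : As)
    {
        assert(ai.size() == M * K);
        for(std::size_t i = 0; i < a.size(); ++i)
            a[i] *= ai[i];
    }

    // Combine the B tensors elementwise (assumed: product).
    Matrix b(K * N, 1.0f);
    for(const auto& bi : Bs)
    {
        assert(bi.size() == K * N);
        for(std::size_t i = 0; i < b.size(); ++i)
            b[i] *= bi[i];
    }

    Matrix e(M * N);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
                acc += a[m * K + k] * b[k * N + n];
            // Fused epilogue: scale, then bias, taken from the D tensors.
            e[m * N + n] = acc * d_scale[m * N + n] + d_bias[m * N + n];
        }
    return e;
}
```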

### batched_contraction

- New subdirectory: tile_engine/ops/gemm/batched_contraction/
- Extends GemmKernelBuilder via BatchedContractionKernelBuilder, adding
  num_dim_g, num_dim_m, num_dim_n, num_dim_k, num_d_tensors, and
  elementwise_function parameters.
- Layout string uses a 3-character encoding, one character each for A,
  B, and E (e.g. rcr).
- Self-contained benchmark sweep driver
  (batched_contraction_benchmark.py) with JSON/CSV export and best-kernel
  selection.
- Supported GPU targets: gfx90a, gfx942, gfx950.
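This sketch illustrates what the num_dim_* parameters describe. It assumes that each group of axes (G, M, N, K) is contiguous and packed. Under that assumption the contraction folds into an ordinary batched GEMM: its batch count, M, N, and K are the products of the extents in each group, and num_dim_g corresponds to the number of G axes (likewise for M, N, and K).

The sketch also decodes the layout string, assuming 'r' means row-major and 'c' means column-major. `ContractionShape`, `fold`, and `decode_layout` are hypothetical names, not the builder's API.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical problem description: extents of each group of axes,
// e.g. two batch (G) axes, two M axes, one N axis, two K axes.
struct ContractionShape
{
    std::vector<std::size_t> g, m, n, k;
};

// The equivalent batched GEMM after folding each group into one axis.
struct FoldedGemm
{
    std::size_t G, M, N, K;
};

std::size_t product(const std::vector<std::size_t>& dims)
{
    return std::accumulate(
        dims.begin(), dims.end(), std::size_t{1}, std::multiplies<std::size_t>());
}

// Assumes each group's axes are contiguous and packed in memory, so the
// group behaves like a single flattened axis.
FoldedGemm fold(const ContractionShape& s)
{
    assert(!s.m.empty() && !s.n.empty() && !s.k.empty());
    return {product(s.g), product(s.m), product(s.n), product(s.k)};
}

// Decode a 3-character layout string (A, B, E); true means row-major.
// Assumed convention: 'r' = row-major, 'c' = column-major.
std::vector<bool> decode_layout(const std::string& layout)
{
    if(layout.size() != 3)
        throw std::invalid_argument("layout must have 3 characters (A, B, E)");
    std::vector<bool> row_major;
    for(char ch : layout)
    {
        if(ch != 'r' && ch != 'c')
            throw std::invalid_argument("layout characters must be 'r' or 'c'");
        row_major.push_back(ch == 'r');
    }
    return row_major;
}
```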

### mx_gemm

  - New subdirectory: tile_engine/ops/gemm/mx_gemm/
  - Supports MX-format (e8m0) microscaling for A and B scale tensors.
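For context, the OCP Microscaling (MX) specification defines e8m0 as an unsigned 8-bit exponent with bias 127. Its value is 2^(e-127), 0xFF is reserved for NaN, and one scale is shared by a block of 32 elements.

The sketch below decodes such scales and applies one per block. It assumes the narrow element values have already been converted to float. Whether mx_gemm uses the same block size along K is an assumption here, not something this PR states. Both function names are hypothetical.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Decode an e8m0 scale: an unsigned biased exponent (bias 127) that
// represents 2^(e - 127); the encoding 0xFF is reserved for NaN.
float decode_e8m0(std::uint8_t e)
{
    if(e == 0xFF)
        return std::numeric_limits<float>::quiet_NaN();
    return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// Apply one shared e8m0 scale per block of `block_size` consecutive
// elements (32 in the OCP MX formats). `elems` holds element values that
// were already converted from their narrow type to float.
std::vector<float> mx_dequantize(const std::vector<float>& elems,
                                 const std::vector<std::uint8_t>& scales,
                                 std::size_t block_size = 32)
{
    assert(block_size > 0);
    assert(scales.size() * block_size >= elems.size());
    std::vector<float> out(elems.size());
    for(std::size_t i = 0; i < elems.size(); ++i)
        out[i] = elems[i] * decode_e8m0(scales[i / block_size]);
    return out;
}
```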

### block_scale_gemm (gemm_rowcolquant, gemm_tensor_quant)

  - New subdirectory: tile_engine/ops/gemm/block_scale_gemm/
  - gemm_rowcolquant: row/column quantization epilogue.
  - gemm_tensor_quant: tensor-level quantization epilogue.
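The two epilogues differ only in scale granularity. Assume the usual conventions:

- Row/column quantization: one scale per row of A and one per column of B.
- Tensor quantization: one scalar per tensor.

Under either convention, the scales factor out of the K sum:

sum_k (s_a[m] a[m][k]) (s_b[n] b[k][n]) = s_a[m] s_b[n] sum_k a[m][k] b[k][n]

Dequantization can therefore be applied once to the accumulator in the epilogue. A sketch with hypothetical function names:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// `acc` is the M x N row-major accumulator of the quantized GEMM,
// before any dequantization.

// Row/column quantization: one scale per row of A and per column of B.
std::vector<float> rowcol_quant_epilogue(const std::vector<float>& acc,
                                         const std::vector<float>& scale_a_row, // size M
                                         const std::vector<float>& scale_b_col, // size N
                                         std::size_t M,
                                         std::size_t N)
{
    assert(acc.size() == M * N);
    assert(scale_a_row.size() == M && scale_b_col.size() == N);
    std::vector<float> c(M * N);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
            c[m * N + n] = acc[m * N + n] * scale_a_row[m] * scale_b_col[n];
    return c;
}

// Tensor quantization: a single scale for all of A and one for all of B.
std::vector<float> tensor_quant_epilogue(const std::vector<float>& acc, float scale_a, float scale_b)
{
    std::vector<float> c(acc.size());
    for(std::size_t i = 0; i < acc.size(); ++i)
        c[i] = acc[i] * scale_a * scale_b;
    return c;
}
```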

### grouped_gemm_quant (grouped_gemm_rowcolquant, grouped_gemm_tensorquant)

  - New subdirectory: tile_engine/ops/gemm/grouped_gemm_quant/
  - grouped_gemm_rowcolquant: grouped GEMM with row/column quantization.
  - grouped_gemm_tensorquant: grouped GEMM with tensor quantization.
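A grouped GEMM runs several independent problems in one launch, each with its own shape and operands. Below is a reference-level sketch of the tensor-quant variant. It assumes one scalar scale for A and one for B in each group, so each group is just a GEMM followed by its own dequantization. `QuantGroup` and `reference_grouped_tensorquant` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical descriptor for one group: each group has its own shape,
// operands, and per-tensor scales.
struct QuantGroup
{
    std::size_t M, N, K;
    std::vector<float> a; // M x K, row-major, quantized values stored as float
    std::vector<float> b; // K x N, row-major, quantized values stored as float
    float scale_a;        // per-tensor scale for this group's A
    float scale_b;        // per-tensor scale for this group's B
};

// CPU reference: every group is an independent GEMM followed by its own
// tensor-quant dequantization; a grouped kernel computes them in one launch.
std::vector<std::vector<float>> reference_grouped_tensorquant(const std::vector<QuantGroup>& groups)
{
    std::vector<std::vector<float>> outs;
    outs.reserve(groups.size());
    for(const auto& g : groups)
    {
        assert(g.a.size() == g.M * g.K && g.b.size() == g.K * g.N);
        std::vector<float> c(g.M * g.N, 0.0f);
        for(std::size_t m = 0; m < g.M; ++m)
            for(std::size_t n = 0; n < g.N; ++n)
            {
                float acc = 0.0f;
                for(std::size_t k = 0; k < g.K; ++k)
                    acc += g.a[m * g.K + k] * g.b[k * g.N + n];
                c[m * g.N + n] = acc * g.scale_a * g.scale_b;
            }
        outs.push_back(std::move(c));
    }
    return outs;
}
```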

### batched_gemm

- New subdirectory: tile_engine/ops/gemm/batched_gemm/
- Batched GEMM benchmark support wired into the sampling/active-op
  lists.
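For reference, a strided batched GEMM computes batch_count independent GEMMs. Their operands sit at fixed offsets (batch strides) within the same buffers. The CPU sketch below assumes row-major, packed matrices; the function and parameter names are hypothetical and not taken from the tile engine.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CPU reference for a strided batched GEMM (row-major, packed leading
// dimensions). For each batch index i:
//   C_i (M x N) = A_i (M x K) * B_i (K x N),
// where A_i starts at offset i * batch_stride_a in `a`, and likewise for B and C.
void reference_batched_gemm(const std::vector<float>& a,
                            const std::vector<float>& b,
                            std::vector<float>& c,
                            std::size_t batch_count,
                            std::size_t M,
                            std::size_t N,
                            std::size_t K,
                            std::size_t batch_stride_a,
                            std::size_t batch_stride_b,
                            std::size_t batch_stride_c)
{
    for(std::size_t batch = 0; batch < batch_count; ++batch)
    {
        const std::size_t oa = batch * batch_stride_a;
        const std::size_t ob = batch * batch_stride_b;
        const std::size_t oc = batch * batch_stride_c;
        // Each batch's slice must lie inside its buffer.
        assert(oa + M * K <= a.size() && ob + K * N <= b.size() && oc + M * N <= c.size());
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.0f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a[oa + m * K + k] * b[ob + k * N + n];
                c[oc + m * N + n] = acc;
            }
    }
}
```

With this formulation, a batch stride of 0 for B makes every batch reuse the same B matrix.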

All new ops are registered in op_weights.json for budget allocation and
wired into the active-op sampling lists in CMakeLists.txt.
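This PR does not show how op_weights.json is consumed. Purely to illustrate weight-based budget allocation, and not as the tile engine's actual scheme, the sketch below splits a total kernel budget across ops in proportion to their weights. It uses largest-remainder rounding, so the shares always add up to the budget. The in-memory map standing in for op_weights.json and the name `allocate_budget` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <string>
#include <utility>
#include <vector>

// Hypothetical: split `budget` kernel instances across ops in proportion
// to their weights (largest-remainder method); shares sum to `budget`.
std::map<std::string, std::size_t> allocate_budget(const std::map<std::string, double>& weights,
                                                   std::size_t budget)
{
    double total = 0.0;
    for(const auto& kv : weights)
    {
        assert(kv.second >= 0.0);
        total += kv.second;
    }
    std::map<std::string, std::size_t> shares;
    if(total <= 0.0)
        return shares; // nothing to allocate against

    // Give each op the floor of its exact proportional share.
    std::size_t assigned = 0;
    std::vector<std::pair<double, std::string>> remainders;
    for(const auto& [op, w] : weights)
    {
        const double exact = budget * (w / total);
        const auto floor_share = static_cast<std::size_t>(std::floor(exact));
        shares[op] = floor_share;
        assigned += floor_share;
        remainders.emplace_back(exact - floor_share, op);
    }

    // Hand out the leftover units to the largest fractional remainders.
    std::sort(remainders.begin(), remainders.end(),
              [](const auto& x, const auto& y) { return x.first > y.first; });
    for(std::size_t i = 0; assigned < budget && i < remainders.size(); ++i, ++assigned)
        shares[remainders[i].second] += 1;
    return shares;
}
```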

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-06-11 20:38:38 +00:00


// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <functional>
#include <tuple>
#include <exception>
#include <sstream>
#include <vector>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
// Structure to hold kernel traits for the dispatcher
struct KernelTraits
{
    std::string pipeline;  // compv3, compv4, mem
    std::string scheduler; // intrawave, interwave
    std::string epilogue;  // cshuffle, default
    bool pad_m;            // enable padding along the M dimension
    bool pad_n;            // enable padding along the N dimension
    bool pad_k;            // enable padding along the K dimension
    bool persistent;       // use a persistent-kernel launch

    // Constructor with defaults
    KernelTraits()
        : pipeline("compv3"),
          scheduler("intrawave"),
          epilogue("cshuffle"),
          pad_m(false),
          pad_n(false),
          pad_k(false),
          persistent(false)
    {
    }
};
inline void add_common_benchmark_args(ck_tile::ArgParser& arg_parser, int default_verify = 2)
{
    arg_parser.insert("m", "3840", "The size of the M dimension. Default is 3840.")
        .insert("n", "4096", "The size of the N dimension. Default is 4096.")
        .insert("k", "2048", "The size of the K dimension. Default is 2048.")
        .insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
        .insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
        .insert("stride_ds", "0", "The stride value for tensor Ds. Default is 0.")
        .insert("stride_c", "0", "The stride value for tensor C. Default is 0.")
        .insert("split_k", "1", "The split-K factor for the K dimension. Default is 1.")
        .insert("verify",
                std::to_string(default_verify),
                "The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
                "for validation on GPU.")
        .insert("log",
                "false",
                "Whether to output kernel instance information. Possible values are true or "
                "false. Default is false.")
        .insert("warmup",
                "50",
                "The number of warm-up iterations before benchmarking the kernel. Default is 50.")
        .insert("repeat",
                "100",
                "The number of iterations to benchmark the kernel. Default is 100.")
        .insert("timer",
                "true",
                "Whether to use the GPU timer. Possible values are true or false. "
                "Default is true.")
        .insert("init",
                "0",
                "The method of tensor initialization. Set to 0 for random, 1 for linear, or 2 "
                "for constant(1). Default is 0, random.")
        .insert("flush_cache",
                "true",
                "Whether to flush the cache. Possible values are true or false. "
                "Default is true.")
        .insert("rotating_count",
                "1000",
                "The number of iterations to rotate the cache. Default is 1000.")
        .insert("metric",
                "0",
                "Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
                "tflops, or 2 for bandwidth. Default is 0, latency.")
        .insert("csv_filename",
                "",
                "The filename for the benchmark results. Default is empty (no CSV output).")
        .insert("structured_sparsity",
                "false",
                "Whether to use the structured-sparsity kernel. Possible values are true or "
                "false. Default is false.")
        .insert("json_output",
                "false",
                "Whether to output results in JSON format only. Possible values are true or "
                "false. Default is false.");
}
// Create the argument parser with the common benchmark flags and parse argv.
// Returns a tuple of (parse succeeded, parser).
inline auto create_args(int argc, char* argv[], int default_verify = 2)
{
    ck_tile::ArgParser arg_parser;
    add_common_benchmark_args(arg_parser, default_verify);
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
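// Overload for op-specific benchmarks: `configure_args` is invoked with the
// parser after the common flags are registered and before argv is parsed,
// so an op can register its own flags. Illustrative use (the "batch_count"
// flag below is hypothetical, not one this header defines):
//
//   auto [ok, parser] = create_args(argc, argv, /*default_verify=*/2,
//       [](ck_tile::ArgParser& p) {
//           p.insert("batch_count", "8", "Number of batches. Default is 8.");
//       });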
template <typename ConfigureArgs>
inline auto create_args(int argc, char* argv[], int default_verify, ConfigureArgs configure_args)
{
    ck_tile::ArgParser arg_parser;
    add_common_benchmark_args(arg_parser, default_verify);
    configure_args(arg_parser);
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
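The comments on KernelTraits list the values each string field is expected to take. The sketch below is a hypothetical convenience for a dispatcher and is not part of gemm_common.hpp. It rejects traits whose pipeline, scheduler, or epilogue falls outside those lists. It redeclares KernelTraits with the same defaults so it compiles on its own; `is_supported` is an illustrative name.

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Standalone copy of the KernelTraits fields and defaults from gemm_common.hpp.
struct KernelTraits
{
    std::string pipeline  = "compv3";    // compv3, compv4, mem
    std::string scheduler = "intrawave"; // intrawave, interwave
    std::string epilogue  = "cshuffle";  // cshuffle, default
    bool pad_m            = false;
    bool pad_n            = false;
    bool pad_k            = false;
    bool persistent       = false;
};

// Hypothetical helper: true if every string field holds one of the values
// listed in the KernelTraits comments.
bool is_supported(const KernelTraits& t)
{
    auto in = [](const std::string& v, const std::vector<std::string>& allowed) {
        return std::find(allowed.begin(), allowed.end(), v) != allowed.end();
    };
    return in(t.pipeline, {"compv3", "compv4", "mem"}) &&
           in(t.scheduler, {"intrawave", "interwave"}) &&
           in(t.epilogue, {"cshuffle", "default"});
}
```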