mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
Users/tlakshma/ck/tile engine develop ## Motivation This PR adds multiple new GPU kernel benchmarking operations to the CK Tile Engine, expanding its coverage of GEMM-family operations: - **gemm_multi_abd**: GEMM with multiple A, B, and D tensors, enabling epilogue patterns such as scale/bias fusion. - **batched_contraction**: Batched tensor contraction supporting multi-dimensional batch (G), M, N, and K dimensions, targeting workloads where the contraction indices span more than one logical axis. - **mx_gemm**: MX-format GEMM with microscaling (e8m0) scale tensors. - **gemm_rowcolquant**: Block-scale GEMM with row/column quantization. - **gemm_tensor_quant**: Block-scale GEMM with tensor quantization. - **grouped_gemm_rowcolquant**: Grouped GEMM with row/column quantization. - **grouped_gemm_tensorquant**: Grouped GEMM with tensor quantization. - **batched_gemm**: Batched GEMM benchmarking support. ## Technical Details ### gemm_multi_abd - New subdirectory: tile_engine/ops/gemm/gemm_multi_abd/ - CMakeLists.txt follows the same individual-target pattern as gemm_universal / gemm_multi_d. - gemm_multi_abd_instance_builder.py subclasses GemmKernelBuilder from the shared gemm_instance_builder.py. - gemm_multi_abd_benchmark.py delegates to the shared GemmBenchmark parent class. - Configs: default_config.json, default_ci_config.json, user_provided_config.json. - Supported GPU targets: gfx90a, gfx942, gfx950, gfx1201. ### batched_contraction - New subdirectory: tile_engine/ops/gemm/batched_contraction/ - Extends GemmKernelBuilder via BatchedContractionKernelBuilder, adding num_dim_g, num_dim_m, num_dim_n, num_dim_k, num_d_tensors, and elementwise_function parameters. - Layout string uses 3-character encoding (A+B+E), e.g. rcr. - Self-contained benchmark sweep driver (batched_contraction_benchmark.py) with JSON/CSV export and best-kernel selection. - Supported GPU targets: gfx90a, gfx942, gfx950. 
### mx_gemm - New subdirectory: tile_engine/ops/gemm/mx_gemm/ - Supports MX-format (e8m0) microscaling for A and B scale tensors. ### block_scale_gemm (gemm_rowcolquant, gemm_tensor_quant) - New subdirectory: tile_engine/ops/gemm/block_scale_gemm/ - gemm_rowcolquant: row/column quantization epilogue. - gemm_tensor_quant: tensor-level quantization epilogue. ### grouped_gemm_quant (grouped_gemm_rowcolquant, grouped_gemm_tensorquant) - New subdirectory: tile_engine/ops/gemm/grouped_gemm_quant/ - grouped_gemm_rowcolquant: grouped GEMM with row/column quantization. - grouped_gemm_tensorquant: grouped GEMM with tensor quantization. ### batched_gemm - New subdirectory: tile_engine/ops/gemm/batched_gemm/ - Batched GEMM benchmark support wired into the sampling/active-op lists. All new ops are registered in op_weights.json for budget allocation and wired into the active-op sampling lists in CMakeLists.txt. ## Test Plan <!-- Explain any relevant testing done to verify this PR. --> ## Test Result <!-- Briefly summarize test outcomes. --> ## Submission Checklist - [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
128 lines
4.4 KiB
C++
128 lines
4.4 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <algorithm>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>

#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "common/utils.hpp"
|
|
|
|
// Data types and Layouts are defined by the generated kernel headers
|
|
// No hardcoded type definitions here to avoid conflicts
|
|
/// @brief Aggregate describing one GEMM benchmark problem; streamable as a JSON-style object.
///
/// Member order is part of the interface: the struct is an aggregate, so positional
/// brace-initialization follows the declaration order below.
struct GemmProblem
{
    // Integer problem parameters, printed as bare numbers.
    int split_k_;
    int m_;
    int n_;
    int k_;
    int stride_a_;
    int stride_b_;
    int stride_c_;

    // Data-type names, printed as quoted strings.
    std::string dtype_a_;
    std::string dtype_b_;
    std::string dtype_acc_;
    std::string dtype_c_;

    // Layout names, printed as quoted strings.
    std::string layout_a_;
    std::string layout_b_;
    std::string layout_c_;

    // Printed as the bare word true/false.
    bool structured_sparsity_;

    /// @brief Writes every field of @p problem to @p os as one `"key":value` entry per line,
    ///        enclosed in braces. String values are emitted verbatim (no escaping).
    /// @param os       Destination stream.
    /// @param problem  Problem description to print.
    /// @return @p os, to allow chaining.
    friend std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os,
                                    const GemmProblem& problem)
    {
        const char* const sparsity_text = problem.structured_sparsity_ ? "true" : "false";

        os << "{\n";

        os << " \"split_k\":" << problem.split_k_ << ",\n";
        os << " \"m\":" << problem.m_ << ",\n";
        os << " \"n\":" << problem.n_ << ",\n";
        os << " \"k\":" << problem.k_ << ",\n";
        os << " \"stride_a\":" << problem.stride_a_ << ",\n";
        os << " \"stride_b\":" << problem.stride_b_ << ",\n";
        os << " \"stride_c\":" << problem.stride_c_ << ",\n";

        os << " \"dtype_a\":\"" << problem.dtype_a_ << "\",\n";
        os << " \"dtype_b\":\"" << problem.dtype_b_ << "\",\n";
        os << " \"dtype_acc\":\"" << problem.dtype_acc_ << "\",\n";
        os << " \"dtype_c\":\"" << problem.dtype_c_ << "\",\n";

        os << " \"layout_a\":\"" << problem.layout_a_ << "\",\n";
        os << " \"layout_b\":\"" << problem.layout_b_ << "\",\n";
        os << " \"layout_c\":\"" << problem.layout_c_ << "\",\n";

        // Last entry: no trailing comma before the closing brace.
        os << " \"structured_sparsity\":" << sparsity_text << "\n";
        os << "}";

        return os;
    }
};
|
|
|
|
/// @brief Trait yielding `Problem::DsDataType` as `type`, or `void` when that member is absent.
///
/// Primary template: used whenever the partial specialization below is not viable.
template <typename Problem, typename Enable = void>
struct get_DsDataType
{
    using type = void;
};

/// Partial specialization: selected when `typename Problem::DsDataType` is well-formed.
template <typename Problem>
struct get_DsDataType<Problem, std::void_t<typename Problem::DsDataType>>
{
    using type = typename Problem::DsDataType;
};
|
|
|
|
/// @brief Trait yielding `Problem::D0DataType` as `type`, or `void` when that member is absent.
///
/// Primary template: used whenever the partial specialization below is not viable.
template <typename Problem, typename Enable = void>
struct get_D0DataType
{
    using type = void;
};

/// Partial specialization: selected when `typename Problem::D0DataType` is well-formed.
template <typename Problem>
struct get_D0DataType<Problem, std::void_t<typename Problem::D0DataType>>
{
    using type = typename Problem::D0DataType;
};
|
|
|
|
/// @brief Generic compare: all types provided explicitly as template parameters
|
|
template <typename AType, typename BType, typename AccType, typename OutType, typename DType = void>
|
|
bool compare(std::string instanceName,
|
|
ck_tile::index_t K,
|
|
ck_tile::index_t kbatch,
|
|
ck_tile::HostTensor<OutType>& out_dev_result,
|
|
ck_tile::HostTensor<OutType>& out_host_result)
|
|
{
|
|
const float max_accumulated_value =
|
|
*std::max_element(out_host_result.mData.begin(), out_host_result.mData.end());
|
|
auto rtol_atol = [&] {
|
|
if constexpr(std::is_void_v<DType>)
|
|
{
|
|
return calculate_rtol_atol<AType, BType, AccType, OutType>(
|
|
K, kbatch, max_accumulated_value);
|
|
}
|
|
else
|
|
{
|
|
return calculate_rtol_atol<AType, BType, DType, AccType, OutType>(
|
|
K, kbatch, max_accumulated_value);
|
|
}
|
|
}();
|
|
bool pass = ck_tile::check_err(out_dev_result,
|
|
out_host_result,
|
|
"Error: Incorrect results!",
|
|
rtol_atol.at(ck_tile::number<0>{}),
|
|
rtol_atol.at(ck_tile::number<1>{}));
|
|
|
|
std::cout << "For " << instanceName << " Relative error threshold is "
|
|
<< rtol_atol.at(ck_tile::number<0>{}) << " Absolute error threshold is "
|
|
<< rtol_atol.at(ck_tile::number<1>{}) << std::endl;
|
|
std::cout << "The verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
|
|
|
return pass;
|
|
}
|
|
|
|
/// @brief Backward-compatible compare: deduces types from global aliases and Problem trait
|
|
template <typename Problem, typename OutType>
|
|
bool compare(std::string instanceName,
|
|
ck_tile::index_t K,
|
|
ck_tile::index_t kbatch,
|
|
ck_tile::HostTensor<OutType>& out_dev_result,
|
|
ck_tile::HostTensor<OutType>& out_host_result)
|
|
{
|
|
using DDataType = typename get_D0DataType<Problem>::type;
|
|
return compare<ADataType, BDataType, AccDataType, OutType, DDataType>(
|
|
instanceName, K, kbatch, out_dev_result, out_host_result);
|
|
}
|