mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 04:37:02 +00:00
Users/tlakshma/ck/tile engine develop

## Motivation

This PR adds multiple new GPU kernel benchmarking operations to the CK Tile Engine, expanding its coverage of GEMM-family operations:

- **gemm_multi_abd**: GEMM with multiple A, B, and D tensors, enabling epilogue patterns such as scale/bias fusion.
- **batched_contraction**: Batched tensor contraction supporting multi-dimensional batch (G), M, N, and K dimensions, targeting workloads where the contraction indices span more than one logical axis.
- **mx_gemm**: MX-format GEMM with microscaling (e8m0) scale tensors.
- **gemm_rowcolquant**: Block-scale GEMM with row/column quantization.
- **gemm_tensor_quant**: Block-scale GEMM with tensor quantization.
- **grouped_gemm_rowcolquant**: Grouped GEMM with row/column quantization.
- **grouped_gemm_tensorquant**: Grouped GEMM with tensor quantization.
- **batched_gemm**: Batched GEMM benchmarking support.

## Technical Details

### gemm_multi_abd

- New subdirectory: tile_engine/ops/gemm/gemm_multi_abd/
- CMakeLists.txt follows the same individual-target pattern as gemm_universal / gemm_multi_d.
- gemm_multi_abd_instance_builder.py subclasses GemmKernelBuilder from the shared gemm_instance_builder.py.
- gemm_multi_abd_benchmark.py delegates to the shared GemmBenchmark parent class.
- Configs: default_config.json, default_ci_config.json, user_provided_config.json.
- Supported GPU targets: gfx90a, gfx942, gfx950, gfx1201.

### batched_contraction

- New subdirectory: tile_engine/ops/gemm/batched_contraction/
- Extends GemmKernelBuilder via BatchedContractionKernelBuilder, adding num_dim_g, num_dim_m, num_dim_n, num_dim_k, num_d_tensors, and elementwise_function parameters.
- Layout string uses 3-character encoding (A+B+E), e.g. rcr.
- Self-contained benchmark sweep driver (batched_contraction_benchmark.py) with JSON/CSV export and best-kernel selection.
- Supported GPU targets: gfx90a, gfx942, gfx950.

### mx_gemm

- New subdirectory: tile_engine/ops/gemm/mx_gemm/
- Supports MX-format (e8m0) microscaling for A and B scale tensors.

### block_scale_gemm (gemm_rowcolquant, gemm_tensor_quant)

- New subdirectory: tile_engine/ops/gemm/block_scale_gemm/
- gemm_rowcolquant: row/column quantization epilogue.
- gemm_tensor_quant: tensor-level quantization epilogue.

### grouped_gemm_quant (grouped_gemm_rowcolquant, grouped_gemm_tensorquant)

- New subdirectory: tile_engine/ops/gemm/grouped_gemm_quant/
- grouped_gemm_rowcolquant: grouped GEMM with row/column quantization.
- grouped_gemm_tensorquant: grouped GEMM with tensor quantization.

### batched_gemm

- New subdirectory: tile_engine/ops/gemm/batched_gemm/
- Batched GEMM benchmark support wired into the sampling/active-op lists.

All new ops are registered in op_weights.json for budget allocation and wired into the active-op sampling lists in CMakeLists.txt.

## Test Plan

<!-- Explain any relevant testing done to verify this PR. -->

## Test Result

<!-- Briefly summarize test outcomes. -->

## Submission Checklist

- [ ] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
241 lines
8.8 KiB
C++
241 lines
8.8 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include <iostream>
|
|
#include <fstream>
|
|
#include <iomanip>
|
|
#include <vector>
|
|
#include <algorithm>
|
|
#include <functional>
|
|
#include <tuple>
|
|
#include <utility>
|
|
#include <type_traits>
|
|
|
|
#include "ck_tile/host/device_prop.hpp"
|
|
#include "ck_tile/ops/gemm.hpp"
|
|
#include "gemm_benchmark.hpp"
|
|
|
|
/// Compile-time detector for a `split_k_` data member.
///
/// Primary template: chosen whenever the member access `std::declval<ProblemT>().split_k_`
/// is not a valid expression, in which case the trait derives from `std::false_type`.
template <typename ProblemT, typename Enable = void>
struct has_split_k_member : std::integral_constant<bool, false>
{
};

/// Partial specialization: participates only when `split_k_` can be accessed on a
/// `ProblemT` value; the discarded-value cast collapses the member's type to `void`
/// so it matches the primary template's default argument.
template <typename ProblemT>
struct has_split_k_member<ProblemT,
                          decltype(static_cast<void>(std::declval<ProblemT>().split_k_))>
    : std::integral_constant<bool, true>
{
};
|
|
|
|
template <typename Gemm, typename Problem, typename GemmArgs>
|
|
class GemmProfiler
|
|
{
|
|
public:
|
|
static Gemm& instance(Settings setting)
|
|
{
|
|
static Gemm instance{setting};
|
|
return instance;
|
|
}
|
|
|
|
// Overload for single kernel benchmarking
|
|
void benchmark(Problem& gemm_problem,
|
|
std::function<float(const GemmArgs&, const ck_tile::stream_config&)> kernel_func)
|
|
{
|
|
// Create a vector with a single callable that returns both name and time
|
|
std::vector<
|
|
std::function<std::tuple<std::string, float>(GemmArgs&, const ck_tile::stream_config&)>>
|
|
callables;
|
|
|
|
callables.push_back([kernel_func](GemmArgs& args, const ck_tile::stream_config& stream) {
|
|
float time = kernel_func(args, stream);
|
|
return std::make_tuple(std::string(KERNEL_NAME), time);
|
|
});
|
|
|
|
benchmark(gemm_problem, callables);
|
|
}
|
|
|
|
virtual void benchmark(Problem& gemm_problem,
|
|
std::vector<std::function<std::tuple<std::string, float>(
|
|
GemmArgs&, const ck_tile::stream_config&)>>& callables) = 0;
|
|
|
|
void process_result(const Problem& gemm_problem,
|
|
ck_tile::DeviceMem& c_m_n_dev_buf,
|
|
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
|
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
|
const std::tuple<std::string, float>& kernel_run_result)
|
|
{
|
|
auto [name, avg_time] = kernel_run_result;
|
|
|
|
KernelInstance<Problem> kernel_instance{name, gemm_problem, {-1.0f, -1.0f, -1.0f}};
|
|
|
|
// compute performance metric
|
|
std::size_t flop = get_flop_count(gemm_problem);
|
|
std::size_t num_byte = get_byte_count(gemm_problem);
|
|
|
|
// update
|
|
kernel_instance.perf_result_.latency_ = avg_time;
|
|
kernel_instance.perf_result_.tflops_ = static_cast<float>(flop) / 1.E9 / avg_time;
|
|
kernel_instance.perf_result_.bandwidth_ = num_byte / 1.E6 / avg_time;
|
|
|
|
if(setting_.log > 0 && !setting_.json_output)
|
|
{
|
|
std::cout << kernel_instance << std::endl;
|
|
}
|
|
|
|
// verify result
|
|
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
|
bool verified_correct =
|
|
!setting_.verify || compare<Problem>(name,
|
|
gemm_problem.k_,
|
|
get_verification_split_k(gemm_problem),
|
|
c_m_n_dev_result,
|
|
c_m_n_host_result);
|
|
|
|
if(verified_correct)
|
|
{
|
|
kernel_instances_.emplace_back(kernel_instance);
|
|
}
|
|
else
|
|
{
|
|
std::cout << "Verification failed, skip kernel: " << name << std::endl;
|
|
}
|
|
|
|
// clear tensor
|
|
c_m_n_dev_buf.SetZero();
|
|
c_m_n_dev_result.SetZero();
|
|
}
|
|
|
|
KernelInstance<Problem> select_best_instance(Metric metric)
|
|
{
|
|
if(kernel_instances_.empty())
|
|
throw std::runtime_error("Empty instances");
|
|
|
|
auto kernel_instance = *std::max_element(kernel_instances_.begin(),
|
|
kernel_instances_.end(),
|
|
[metric](const auto& a, const auto& b) {
|
|
return PerformanceResult::compare(
|
|
b.perf_result_, a.perf_result_, metric);
|
|
});
|
|
|
|
if(setting_.json_output)
|
|
{
|
|
// Output clean JSON only
|
|
std::cout << kernel_instance << std::endl;
|
|
}
|
|
else
|
|
{
|
|
std::cout << "**********************************" << std::endl;
|
|
std::cout << "According to given metrics: " << get_metric_name(metric) << "\n"
|
|
<< "Current kernel performance is: " << kernel_instance << std::endl;
|
|
std::cout << "**********************************" << std::endl;
|
|
}
|
|
|
|
if(!setting_.csv_filename.empty())
|
|
{
|
|
std::ofstream file(setting_.csv_filename + ".csv", std::ios::app);
|
|
|
|
if(!file.is_open())
|
|
{
|
|
std::cerr << "Warning: Failed to open CSV file for writing." << std::endl;
|
|
}
|
|
else
|
|
{
|
|
if(file.tellp() == 0)
|
|
{
|
|
write_csv_header(file);
|
|
}
|
|
|
|
write_csv_row(file, kernel_instance, metric);
|
|
|
|
if(!file)
|
|
{
|
|
std::cerr << "Warning: Error occurred while writing to CSV file." << std::endl;
|
|
}
|
|
}
|
|
}
|
|
|
|
return kernel_instance;
|
|
}
|
|
|
|
GemmProfiler(const GemmProfiler&) = delete;
|
|
GemmProfiler& operator=(const GemmProfiler&) = delete;
|
|
|
|
protected:
|
|
virtual ~GemmProfiler() { kernel_instances_.clear(); }
|
|
GemmProfiler(Settings setting) : setting_(setting) {}
|
|
|
|
virtual std::size_t get_flop_count(const Problem& gemm_problem) const
|
|
{
|
|
using DDataType = typename get_DsDataType<Problem>::type;
|
|
|
|
std::size_t flop = std::size_t(2) * gemm_problem.m_ * gemm_problem.n_ * gemm_problem.k_;
|
|
if constexpr(!std::is_void_v<DDataType>)
|
|
{
|
|
ck_tile::static_for<0, DDataType::size(), 1>{}([&](auto i) {
|
|
using DType = ck_tile::remove_cvref_t<std::tuple_element_t<i, DDataType>>;
|
|
static_cast<void>(sizeof(DType));
|
|
flop += gemm_problem.m_ * gemm_problem.n_;
|
|
});
|
|
}
|
|
return flop;
|
|
}
|
|
|
|
virtual std::size_t get_byte_count(const Problem& gemm_problem) const
|
|
{
|
|
using DDataType = typename get_DsDataType<Problem>::type;
|
|
|
|
std::size_t num_byte = sizeof(ADataType) * gemm_problem.m_ * gemm_problem.k_ +
|
|
sizeof(BDataType) * gemm_problem.n_ * gemm_problem.k_ +
|
|
sizeof(CDataType) * gemm_problem.m_ * gemm_problem.n_;
|
|
|
|
if constexpr(!std::is_void_v<DDataType>)
|
|
{
|
|
ck_tile::static_for<0, DDataType::size(), 1>{}([&](auto i) {
|
|
using DType = ck_tile::remove_cvref_t<std::tuple_element_t<i, DDataType>>;
|
|
num_byte += sizeof(DType) * gemm_problem.m_ * gemm_problem.n_;
|
|
});
|
|
}
|
|
return num_byte;
|
|
}
|
|
|
|
virtual int get_verification_split_k(const Problem& gemm_problem) const
|
|
{
|
|
if constexpr(has_split_k_member<Problem>::value)
|
|
{
|
|
return gemm_problem.split_k_;
|
|
}
|
|
return 1;
|
|
}
|
|
|
|
virtual void write_csv_header(std::ostream& os) const
|
|
{
|
|
os << "rocm_version,device_name," << "split_k,m,n,k,stride_a,stride_b,stride_c,"
|
|
<< "dtype_a,dtype_b,dtype_acc,dtype_c," << "layout_a,layout_b,layout_c,"
|
|
<< "structured_sparsity," << "name,"
|
|
<< "latency(ms),tflops(TFlops),bandwidth(GB/s),metric\n";
|
|
}
|
|
|
|
virtual void write_csv_row(std::ostream& os,
|
|
const KernelInstance<Problem>& kernel_instance,
|
|
Metric metric) const
|
|
{
|
|
const auto& problem = kernel_instance.problem_;
|
|
const auto& name = kernel_instance.name_;
|
|
const auto& perf = kernel_instance.perf_result_;
|
|
|
|
os << get_rocm_version() << "," << ck_tile::get_device_name() << ","
|
|
<< get_verification_split_k(problem) << "," << problem.m_ << "," << problem.n_ << ","
|
|
<< problem.k_ << "," << problem.stride_a_ << "," << problem.stride_b_ << ","
|
|
<< problem.stride_c_ << "," << problem.dtype_a_ << "," << problem.dtype_b_ << ","
|
|
<< problem.dtype_acc_ << "," << problem.dtype_c_ << "," << problem.layout_a_ << ","
|
|
<< problem.layout_b_ << "," << problem.layout_c_ << "," << problem.structured_sparsity_
|
|
<< "," << name << "," << std::fixed << std::setprecision(4) << perf.latency_ << ","
|
|
<< std::fixed << std::setprecision(4) << perf.tflops_ << "," << std::fixed
|
|
<< std::setprecision(4) << perf.bandwidth_ << "," << get_metric_name(metric) << "\n";
|
|
}
|
|
|
|
Settings setting_;
|
|
|
|
std::vector<KernelInstance<Problem>> kernel_instances_;
|
|
};
|