Restructure the Tile Engine to have faster build time and clear config report (#2747)

* Making edits to identify individual compilation issues.

* Minor fix for blob txt files not being created.

* Fixing compilation issues.

* Fixing ordering bug.

* Adding python profiling functionality.

* Setting individual build as default.

* Setting gpu target filtering for tile engine to gfx90a, gfx942 and gfx950.

* update the default running parameters and settings

* Fixing bug with benchmarking, shifting file generation to build instead of config.

* Updating fixes.

* Fixing json output and parsing.

* Disable ccache for tile engine gemm ops because we dont need it.

* Removing duplicate type definition.

* Improving json printing.

* Add the flexibility of different layout and more warp tile support

* Fix extra flag in name of individual kernels.

* Fixing bug with booleans.

* Solve the first patch of the post merge conflict

* Compilation fixes, and cosmetic improvements.

* Yet again compilation fixes after latest changes from develop.

* Fixing python benchmarking script.

---------

Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com>
Co-authored-by: Vidyasagar Ananthan <vanantha@amd.com>
This commit is contained in:
Thomas Ning
2025-08-30 09:54:18 -04:00
committed by GitHub
parent fcff0043ae
commit 705804d9bf
17 changed files with 3361 additions and 1409 deletions

View File

@@ -0,0 +1,160 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <functional>
#include <tuple>
#include <exception>
#include <sstream>
#include <vector>
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "gemm_profiler.hpp"
#include "gemm_common.hpp"
// The kernel header is included via the compile command line with -include flag
// It defines SelectedKernel struct and KERNEL_NAME
// DataTypeTraits are now defined in gemm_common.hpp
// Create argument parser
inline auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "The value for m dimension. Default is 3840.")
.insert("n", "4096", "The value for n dimension. Default is 4096.")
.insert("k", "2048", "The value for k dimension. Default is 2048.")
.insert("stride_a", "0", "The stride value for tensor A. Default is 0.")
.insert("stride_b", "0", "The stride value for tensor B. Default is 0.")
.insert("stride_c", "0", "The stride value for tensor C. Default is 0.")
.insert("split_k", "1", "The split value for k dimension. Default is 1.")
.insert("verify",
"0",
"The type of validation. Set to 0 for no validation, 1 for validation on CPU, or 2 "
"for validation on GPU. Default is 0, no validation.")
.insert("log",
"false",
"Whether output kernel instance information or not. Possible values are true or "
"false. Default is false")
.insert(
"warmup", "50", "The number of iterations before benchmark the kernel. Default is 50.")
.insert(
"repeat", "100", "The number of iterations to benchmark the kernel. Default is 100.")
.insert("timer",
"true",
"Whether if the timer is gpu timer or not. Possible values are false or true. "
"Default is true.")
.insert("init",
"0",
"The method of tensor initialization. Set to 0 for random, to 1 for linear, or 2 "
"for constant(1). Default is 0, random.")
.insert("flush_cache",
"true",
"To flush cache, possible values are true or false. "
"Default is false.")
.insert("rotating_count", "1000", "number of iterations to rotate the cache. default is 5.")
.insert("metric",
"0",
"Metric with which to measure kernel performance. Set to 0 for latency, 1 for "
"tflops, or 2 for bandwidth. Default is 0, latency.")
.insert("csv_filename",
"",
"The filename of benchmark result. Default is empty (no CSV output).")
.insert("structured_sparsity",
"false",
"Whether use sparsity kernel or not. Possible values are true or false. Default is "
"false")
.insert("json_output",
"false",
"Whether to output results in JSON format only. Possible values are true or false. "
"Default is "
"false");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
void benchmark_gemm_single(const ck_tile::ArgParser& arg_parser)
{
// Use DataTypeTraits to get the actual type names from the generated header
// The generated header defines ADataType, BDataType, AccDataType, CDataType
std::string dtype_a = DataTypeTraits<ADataType>::name;
std::string dtype_b = DataTypeTraits<BDataType>::name;
std::string dtype_acc = DataTypeTraits<AccDataType>::name;
std::string dtype_c = DataTypeTraits<CDataType>::name;
// Layout names from the layout types
std::string layout_a = ALayout::name;
std::string layout_b = BLayout::name;
std::string layout_c = CLayout::name;
// Create GemmProblem struct
GemmProblem gemm_problem{arg_parser.get_int("split_k"),
arg_parser.get_int("m"),
arg_parser.get_int("n"),
arg_parser.get_int("k"),
arg_parser.get_int("stride_a"),
arg_parser.get_int("stride_b"),
arg_parser.get_int("stride_c"),
dtype_a,
dtype_b,
dtype_acc,
dtype_c,
layout_a,
layout_b,
layout_c,
arg_parser.get_bool("structured_sparsity")};
// Create Setting struct
Setting setting{arg_parser.get_int("warmup"),
arg_parser.get_int("repeat"),
arg_parser.get_bool("timer"),
arg_parser.get_int("verify"),
arg_parser.get_int("init"),
arg_parser.get_bool("log"),
arg_parser.get_str("csv_filename"),
arg_parser.get_bool("flush_cache"),
arg_parser.get_int("rotating_count"),
arg_parser.get_bool("json_output")};
// Get the profiler instance
auto& profiler = GemmProfiler::instance(setting);
try
{
// Create a lambda that wraps the kernel launch
auto kernel_func = [](const ck_tile::GemmHostArgs& args,
const ck_tile::stream_config& stream) {
return SelectedKernel::launch(args, stream);
};
// Benchmark the kernel
profiler.benchmark(gemm_problem, kernel_func);
// Select best instance based on metric
profiler.select_best_instance(static_cast<Metric>(arg_parser.get_int("metric")));
}
catch(const std::exception& e)
{
std::cerr << "Benchmark failed: " << e.what() << std::endl;
}
}
int main(int argc, char* argv[])
{
try
{
auto [result, parser] = create_args(argc, argv);
if(!result)
return EXIT_FAILURE;
benchmark_gemm_single(parser);
return 0;
}
catch(const std::exception& e)
{
std::cerr << "Error: " << e.what() << "\n";
return EXIT_FAILURE;
}
}