[CK_TILE] Stream-K Tile Engine Test Config File Generation (#3662)

* Stream-K smoke test config file generation

This change converts the Stream-K smoke tests to use tile engine. Since
the m, n, and k values depend on the CU count of a device, the
configs are generated during the Configuration Phase.

* Compute GEMM reference on GPU

* Remove redundant Stream-K tests

Removing redundant tests that are now run via tile engine.

* Fix relative and absolute tolerance calculation

This change updates the Stream-K tile engine interface to ensure that
num_wgs_per_tile is propagated and passed into the compare_results
function to calculate the relative and absolute tolerances. Before,
split-k was used, which is incorrect for Stream-K since the split-k
value is always 1.

* Cleanup imports, types, and other misc items

This commit makes the following changes:
- Uses the typing module for nested type hints
- Uses quotes around cu_count_arg argument in generate_configs.cmake in
  if statements
- Adds explicit include for tuple in test_gemm_streamk_simple.cpp
- Adds a type for the tiles argument in argparser to check argument
  validity

* Use CU count as return value for better parsing

* Add reduction tests for bf16, fp8, and bf8
This commit is contained in:
Emily Martins
2026-02-03 09:12:15 -07:00
committed by GitHub
parent 3f04d27b68
commit 8cbd09c84a
22 changed files with 522 additions and 406 deletions

View File

@@ -481,8 +481,9 @@ struct SelectedKernel {{
AccDataType,
TileShape,
GemmUniversalTraits>;
static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
static std::tuple<float, ck_tile::index_t> launch(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& stream) {{
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -562,12 +563,16 @@ struct SelectedKernel {{
workspace_data.SetZero();
}}
}};
const ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
// Launch kernel
return ck_tile::launch_kernel_time_mask(
const float time = ck_tile::launch_kernel_time_mask(
stream,
reset_data_buffers,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return std::tuple<float, ck_tile::index_t>{{time, num_wgs_per_tile}};
}}
}};
"""

View File

@@ -22,25 +22,25 @@ class GemmProfiler
// Overload for single kernel benchmarking
void benchmark(GemmProblem& gemm_problem,
std::function<float(const ck_tile::StreamKHostArgs&,
const ck_tile::stream_config&)> kernel_func)
std::function<std::tuple<float, ck_tile::index_t>(
const ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)> kernel_func)
{
// Create a vector with a single callable that returns both name and time
std::vector<std::function<std::tuple<std::string, float>(ck_tile::StreamKHostArgs&,
const ck_tile::stream_config&)>>
// Create a vector with a single callable that returns name, time, and num_wgs_per_tile
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>
callables;
callables.push_back(
[kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {
float time = kernel_func(args, stream);
return std::make_tuple(std::string(KERNEL_NAME), time);
auto [time, num_wgs_per_tile] = kernel_func(args, stream);
return std::make_tuple(std::string(KERNEL_NAME), time, num_wgs_per_tile);
});
benchmark(gemm_problem, callables);
}
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
@@ -160,9 +160,9 @@ class GemmProfiler
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const std::tuple<std::string, float>& kernel_run_result)
const std::tuple<std::string, float, ck_tile::index_t>& kernel_run_result)
{
auto [name, avg_time] = kernel_run_result;
auto [name, avg_time, num_wgs_per_tile] = kernel_run_result;
auto dp_persistent =
SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
@@ -196,8 +196,7 @@ class GemmProfiler
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool verified_correct =
!setting_.verify_ ||
compare(
name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);
compare(name, gemm_problem.k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_host_result);
if(verified_correct)
{