mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_TILE] Stream-K Tile Engine Test Config File Generation (#3662)
* Stream-K smoke test config file generation This change converts the stream-k smoke tests to use tile engine. Since the m, n, and k values dependent on the CU count of a device, the configs are generated during the Configuration Phase. * Compute GEMM reference on GPU * Remove redundant Stream-K tests Removing redundant tests that are now run via tile engine. * Fix relative and absolute tolerance calculation This change updates the Stream-K tile engine interface to ensure that num_wgs_per_tile is propaged and passed into the compare_results function to calculate the rel and abs tolerance. Before, split-k was used, which is incorrect for Stream-K since the split-k value is always 1. * Cleanup imports, types, and other misc items This commit makes the following changes: - Uses Typing module for nested type hints - Uses quotes around cu_count_arg argument in generate_configs.cmake in if statements - Adds explicit include for tuple in test_gemm_streamk_simple.cpp - Adds a type for the tiles argument in argparser to check argument validity * Use CU count as return value for better parsing * Add reduction tests for bf16, fp8, and bf8
This commit is contained in:
@@ -481,8 +481,9 @@ struct SelectedKernel {{
|
||||
AccDataType,
|
||||
TileShape,
|
||||
GemmUniversalTraits>;
|
||||
|
||||
static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
|
||||
static std::tuple<float, ck_tile::index_t> launch(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& stream) {{
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
@@ -562,12 +563,16 @@ struct SelectedKernel {{
|
||||
workspace_data.SetZero();
|
||||
}}
|
||||
}};
|
||||
|
||||
const ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
|
||||
// Launch kernel
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
const float time = ck_tile::launch_kernel_time_mask(
|
||||
stream,
|
||||
reset_data_buffers,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return std::tuple<float, ck_tile::index_t>{{time, num_wgs_per_tile}};
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
@@ -22,25 +22,25 @@ class GemmProfiler
|
||||
|
||||
// Overload for single kernel benchmarking
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::function<float(const ck_tile::StreamKHostArgs&,
|
||||
const ck_tile::stream_config&)> kernel_func)
|
||||
std::function<std::tuple<float, ck_tile::index_t>(
|
||||
const ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)> kernel_func)
|
||||
{
|
||||
// Create a vector with a single callable that returns both name and time
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::StreamKHostArgs&,
|
||||
const ck_tile::stream_config&)>>
|
||||
// Create a vector with a single callable that returns name, time, and num_wgs_per_tile
|
||||
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
|
||||
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>
|
||||
callables;
|
||||
|
||||
callables.push_back(
|
||||
[kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {
|
||||
float time = kernel_func(args, stream);
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time);
|
||||
auto [time, num_wgs_per_tile] = kernel_func(args, stream);
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time, num_wgs_per_tile);
|
||||
});
|
||||
|
||||
benchmark(gemm_problem, callables);
|
||||
}
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
|
||||
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
@@ -160,9 +160,9 @@ class GemmProfiler
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
const std::tuple<std::string, float, ck_tile::index_t>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
auto [name, avg_time, num_wgs_per_tile] = kernel_run_result;
|
||||
auto dp_persistent =
|
||||
SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
|
||||
|
||||
@@ -196,8 +196,7 @@ class GemmProfiler
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool verified_correct =
|
||||
!setting_.verify_ ||
|
||||
compare(
|
||||
name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);
|
||||
compare(name, gemm_problem.k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_host_result);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user