Merge commit '8cbd09c84a3010b4b3dbe2604875772363e2396b' into develop

This commit is contained in:
assistant-librarian[bot]
2026-02-03 16:29:00 +00:00
parent ef6ce49698
commit fc1ff7a1f8
22 changed files with 522 additions and 406 deletions

View File

@@ -481,8 +481,9 @@ struct SelectedKernel {{
AccDataType,
TileShape,
GemmUniversalTraits>;
static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
static std::tuple<float, ck_tile::index_t> launch(const ck_tile::StreamKHostArgs& args,
const ck_tile::stream_config& stream) {{
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
@@ -562,12 +563,16 @@ struct SelectedKernel {{
workspace_data.SetZero();
}}
}};
const ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
// Launch kernel
return ck_tile::launch_kernel_time_mask(
const float time = ck_tile::launch_kernel_time_mask(
stream,
reset_data_buffers,
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
return std::tuple<float, ck_tile::index_t>{{time, num_wgs_per_tile}};
}}
}};
"""

View File

@@ -22,25 +22,25 @@ class GemmProfiler
// Overload for single kernel benchmarking
void benchmark(GemmProblem& gemm_problem,
std::function<float(const ck_tile::StreamKHostArgs&,
const ck_tile::stream_config&)> kernel_func)
std::function<std::tuple<float, ck_tile::index_t>(
const ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)> kernel_func)
{
// Create a vector with a single callable that returns both name and time
std::vector<std::function<std::tuple<std::string, float>(ck_tile::StreamKHostArgs&,
const ck_tile::stream_config&)>>
// Create a vector with a single callable that returns name, time, and num_wgs_per_tile
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>
callables;
callables.push_back(
[kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {
float time = kernel_func(args, stream);
return std::make_tuple(std::string(KERNEL_NAME), time);
auto [time, num_wgs_per_tile] = kernel_func(args, stream);
return std::make_tuple(std::string(KERNEL_NAME), time, num_wgs_per_tile);
});
benchmark(gemm_problem, callables);
}
void benchmark(GemmProblem& gemm_problem,
std::vector<std::function<std::tuple<std::string, float>(
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables)
{
const ALayout layout_a = ALayout{};
@@ -160,9 +160,9 @@ class GemmProfiler
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
const std::tuple<std::string, float>& kernel_run_result)
const std::tuple<std::string, float, ck_tile::index_t>& kernel_run_result)
{
auto [name, avg_time] = kernel_run_result;
auto [name, avg_time, num_wgs_per_tile] = kernel_run_result;
auto dp_persistent =
SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
@@ -196,8 +196,7 @@ class GemmProfiler
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool verified_correct =
!setting_.verify_ ||
compare(
name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);
compare(name, gemm_problem.k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_host_result);
if(verified_correct)
{