mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Merge commit '8cbd09c84a3010b4b3dbe2604875772363e2396b' into develop
This commit is contained in:
@@ -481,8 +481,9 @@ struct SelectedKernel {{
|
||||
AccDataType,
|
||||
TileShape,
|
||||
GemmUniversalTraits>;
|
||||
|
||||
static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{
|
||||
|
||||
static std::tuple<float, ck_tile::index_t> launch(const ck_tile::StreamKHostArgs& args,
|
||||
const ck_tile::stream_config& stream) {{
|
||||
constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
@@ -562,12 +563,16 @@ struct SelectedKernel {{
|
||||
workspace_data.SetZero();
|
||||
}}
|
||||
}};
|
||||
|
||||
const ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile();
|
||||
|
||||
// Launch kernel
|
||||
return ck_tile::launch_kernel_time_mask(
|
||||
const float time = ck_tile::launch_kernel_time_mask(
|
||||
stream,
|
||||
reset_data_buffers,
|
||||
ck_tile::make_kernel<kBlockPerCu>(GemmKernel{{}}, grids, blocks, 0, kargs));
|
||||
|
||||
return std::tuple<float, ck_tile::index_t>{{time, num_wgs_per_tile}};
|
||||
}}
|
||||
}};
|
||||
"""
|
||||
|
||||
@@ -22,25 +22,25 @@ class GemmProfiler
|
||||
|
||||
// Overload for single kernel benchmarking
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::function<float(const ck_tile::StreamKHostArgs&,
|
||||
const ck_tile::stream_config&)> kernel_func)
|
||||
std::function<std::tuple<float, ck_tile::index_t>(
|
||||
const ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)> kernel_func)
|
||||
{
|
||||
// Create a vector with a single callable that returns both name and time
|
||||
std::vector<std::function<std::tuple<std::string, float>(ck_tile::StreamKHostArgs&,
|
||||
const ck_tile::stream_config&)>>
|
||||
// Create a vector with a single callable that returns name, time, and num_wgs_per_tile
|
||||
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
|
||||
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>
|
||||
callables;
|
||||
|
||||
callables.push_back(
|
||||
[kernel_func](ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {
|
||||
float time = kernel_func(args, stream);
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time);
|
||||
auto [time, num_wgs_per_tile] = kernel_func(args, stream);
|
||||
return std::make_tuple(std::string(KERNEL_NAME), time, num_wgs_per_tile);
|
||||
});
|
||||
|
||||
benchmark(gemm_problem, callables);
|
||||
}
|
||||
|
||||
void benchmark(GemmProblem& gemm_problem,
|
||||
std::vector<std::function<std::tuple<std::string, float>(
|
||||
std::vector<std::function<std::tuple<std::string, float, ck_tile::index_t>(
|
||||
ck_tile::StreamKHostArgs&, const ck_tile::stream_config&)>>& callables)
|
||||
{
|
||||
const ALayout layout_a = ALayout{};
|
||||
@@ -160,9 +160,9 @@ class GemmProfiler
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_host_result,
|
||||
ck_tile::HostTensor<CDataType>& c_m_n_dev_result,
|
||||
const std::tuple<std::string, float>& kernel_run_result)
|
||||
const std::tuple<std::string, float, ck_tile::index_t>& kernel_run_result)
|
||||
{
|
||||
auto [name, avg_time] = kernel_run_result;
|
||||
auto [name, avg_time, num_wgs_per_tile] = kernel_run_result;
|
||||
auto dp_persistent =
|
||||
SelectedKernel::UsePersistentKernel ? "PersistentKernel" : "NonPersistentKernel";
|
||||
|
||||
@@ -196,8 +196,7 @@ class GemmProfiler
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool verified_correct =
|
||||
!setting_.verify_ ||
|
||||
compare(
|
||||
name, gemm_problem.k_, gemm_problem.split_k_, c_m_n_dev_result, c_m_n_host_result);
|
||||
compare(name, gemm_problem.k_, num_wgs_per_tile, c_m_n_dev_result, c_m_n_host_result);
|
||||
|
||||
if(verified_correct)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user