Profile optionally only a given instance.

This commit is contained in:
Ville Pietilä
2026-02-04 06:31:36 -05:00
parent f6f381dbd4
commit 2cfc4209bb
2 changed files with 17 additions and 5 deletions

View File

@@ -44,6 +44,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
bool do_log,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
std::optional<std::string> run_instance = std::nullopt,
const OutElementOp out_element_op = OutElementOp{},
index_t instance_index = -1)
{
@@ -232,6 +233,12 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
return;
}
if (run_instance.has_value() && !run_instance.value().empty() && op_ptr->GetTypeString().find(run_instance.value()) == std::string::npos)
{
// skip if run_instance is specified and does not match op name
return;
}
std::string op_name = op_ptr->GetTypeString();
valids++;

View File

@@ -66,7 +66,9 @@ static void print_helper_msg()
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
<< ck::utils::conv::get_conv_param_parser_helper_msg()
<< "last arg: run only given instance (string), optional\n"
<< std::endl;
// clang-format on
}
@@ -90,14 +92,17 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
const bool time_kernel = std::stoi(argv[8]);
const int num_dim_spatial = std::stoi(argv[9]);
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, 6 * num_dim_spatial, and optionally 1 for instance name
const int base_number_of_args = 9 + 1 + 4 + 6 * num_dim_spatial;
if(argc != base_number_of_args && argc != base_number_of_args + 1)
{
print_helper_msg();
return 1;
}
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
const std::string run_instance =
(argc == base_number_of_args + 1) ? std::string(argv[base_number_of_args]) : "";
using F32 = float;
using F16 = ck::half_t;
@@ -178,7 +183,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
AComputeType,
BComputeType,
ck::index_t>(
do_verification, init_method, do_log, time_kernel, params);
do_verification, init_method, do_log, time_kernel, params, run_instance);
return pass ? 0 : 1;
}
@@ -194,7 +199,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
AComputeType,
BComputeType,
ck::long_index_t>(
do_verification, init_method, do_log, time_kernel, params);
do_verification, init_method, do_log, time_kernel, params, run_instance);
return pass ? 0 : 1;
}