mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Profile optionally only a given instance.
This commit is contained in:
@@ -44,6 +44,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
std::optional<std::string> run_instance = std::nullopt,
|
||||
const OutElementOp out_element_op = OutElementOp{},
|
||||
index_t instance_index = -1)
|
||||
{
|
||||
@@ -232,6 +233,12 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
|
||||
return;
|
||||
}
|
||||
|
||||
if (run_instance.has_value() && !run_instance.value().empty() && op_ptr->GetTypeString().find(run_instance.value()) == std::string::npos)
|
||||
{
|
||||
// skip if run_instance is specified and does not match op name
|
||||
return;
|
||||
}
|
||||
|
||||
std::string op_name = op_ptr->GetTypeString();
|
||||
valids++;
|
||||
|
||||
|
||||
@@ -66,7 +66,9 @@ static void print_helper_msg()
|
||||
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg7: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg8: time kernel (0: no, 1: yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg()
|
||||
<< "last arg: run only given instance (string), optional\n"
|
||||
<< std::endl;
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -90,14 +92,17 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
const bool time_kernel = std::stoi(argv[8]);
|
||||
const int num_dim_spatial = std::stoi(argv[9]);
|
||||
|
||||
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
|
||||
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
|
||||
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, 6 * num_dim_spatial, and optionally 1 for instance name
|
||||
const int base_number_of_args = 9 + 1 + 4 + 6 * num_dim_spatial;
|
||||
if(argc != base_number_of_args && argc != base_number_of_args + 1)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
}
|
||||
|
||||
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 10, argv);
|
||||
const std::string run_instance =
|
||||
(argc == base_number_of_args + 1) ? std::string(argv[base_number_of_args]) : "";
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
@@ -178,7 +183,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
ck::index_t>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
do_verification, init_method, do_log, time_kernel, params, run_instance);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -194,7 +199,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
ck::long_index_t>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
do_verification, init_method, do_log, time_kernel, params, run_instance);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user