mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
[rocm-libraries] ROCm/rocm-libraries#4800 (commit 9dcf0cf)
[CK Profiler] Instance selection for grouped conv profilers (#4800) ## Motivation This PR adds instance selection support for ckProfiler grouped convolution operations (forward, backward data, backward weight), allowing users to run specific kernel instances rather than sweeping all available instances. When profiling or debugging convolution kernels, users often need to test specific kernel configurations without running the full instance sweep. This is particularly useful for: - Debugging a specific failing instance - Profiling a known-best configuration - Quick validation during development ## Technical Details **Features added**: - `--instance <id>` flag to run only the N-th valid instance (0-indexed) - `--list-instances` flag to list all valid instances without running any kernels - Named arguments can appear anywhere on the command line - Best instance index is now printed with results for reference - Python script support via `-ii` / `--instance_index` arguments **Design decisions**: - Named arguments (`--instance`, `--list-instances`) instead of positional to avoid conflicts with existing parameters - Instance index refers to the N-th valid instance (0-indexed), not the global instance index - Auto-disable verification when `--list-instances` is used for fast enumeration - Shared utilities in `profiler_arg_utils.hpp` to deduplicate parsing logic ## Test Plan Manual testing with various scenarios: List all valid instances: ```bash ./bin/ckProfiler grouped_conv_fwd <usual args> --list-instances ``` Run only instance 5: ```bash ./bin/ckProfiler grouped_conv_fwd <usual args> --instance 5 ``` Test cases: - Single instance selection - List instances mode - Out-of-bounds instance index (verified warning messages) - No instance flag (runs all instances - default behavior) - All three operations (fwd, bwd_data, bwd_weight) ## Test Result All test scenarios passed: - Instance selection correctly filters kernel executions - List mode enumerates valid instances without running kernels - Invalid indices produce appropriate warnings without crashing - Default behavior (all instances) unchanged when flags not provided - Consistent behavior across all three grouped convolution operations ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
fff9f72ffb
commit
1cd031c21d
@@ -5,8 +5,8 @@
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
|
||||
#include "profiler/profiler_arg_utils.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
namespace {
|
||||
@@ -48,7 +48,10 @@ static void print_helper_msg()
|
||||
<< "arg6: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg7: time kernel (0: no, 1: yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl
|
||||
<< "arg8: split-K (0: internally computed split-K value; 1, 2, 4, 8, 16, 32, 64, 128: set k batches explicitly)\n";
|
||||
<< "arg8: split-K (0: internally computed split-K value; 1, 2, 4, 8, 16, 32, 64, 128: set k batches explicitly)\n"
|
||||
<< "\nOptional arguments:\n"
|
||||
<< " --instance <id> Run only the specified instance (0-indexed among valid instances)\n"
|
||||
<< " --list-instances List all valid instances without running\n";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -56,8 +59,17 @@ static void print_helper_msg()
|
||||
|
||||
int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
{
|
||||
// Parse optional named arguments first
|
||||
ck::index_t instance_index = -1;
|
||||
bool list_instances = false;
|
||||
ck::profiler::parse_named_args(argc, argv, instance_index, list_instances);
|
||||
const int named_arg_count = ck::profiler::count_named_args(argc, argv);
|
||||
|
||||
// Adjust argc for positional argument checking
|
||||
const int positional_argc = argc - named_arg_count;
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial
|
||||
if(argc < 9)
|
||||
if(positional_argc < 9)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -72,7 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
if(positional_argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -111,15 +123,22 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
using InDataType = decltype(in_type);
|
||||
using ComputeDataType = decltype(compute_type);
|
||||
|
||||
bool pass = ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
InLayout,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
InDataType,
|
||||
ComputeDataType>(
|
||||
do_verification, init_method, do_log, time_kernel, params, split_k);
|
||||
bool pass =
|
||||
ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
InLayout,
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
InDataType,
|
||||
ComputeDataType>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
params,
|
||||
split_k,
|
||||
instance_index,
|
||||
list_instances);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include <numeric>
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
|
||||
#include "profiler/profiler_arg_utils.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
namespace {
|
||||
@@ -61,6 +62,10 @@ static void print_helper_msg()
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg()
|
||||
<< " SplitK (-1 for internally computed split-K value, positive value to set k "
|
||||
"batches explicitly, or 'all' to test all internal split-K values)\n"
|
||||
<< "\nOptional arguments:\n"
|
||||
<< " --instance <id> Run only the specified instance (0-indexed among valid "
|
||||
"instances)\n"
|
||||
<< " --list-instances List all valid instances without running\n"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
@@ -68,8 +73,17 @@ static void print_helper_msg()
|
||||
|
||||
int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
// Parse optional named arguments first
|
||||
ck::index_t instance_index = -1;
|
||||
bool list_instances = false;
|
||||
ck::profiler::parse_named_args(argc, argv, instance_index, list_instances);
|
||||
const int named_arg_count = ck::profiler::count_named_args(argc, argv);
|
||||
|
||||
// Adjust argc for positional argument checking
|
||||
const int positional_argc = argc - named_arg_count;
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial
|
||||
if(argc < 9)
|
||||
if(positional_argc < 9)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -84,7 +98,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
if(positional_argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -129,16 +143,23 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
using ComputeTypeA = decltype(compute_type_a);
|
||||
using ComputeTypeB = decltype(compute_type_b);
|
||||
|
||||
bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>(
|
||||
do_verification, init_method, do_log, time_kernel, params, split_k);
|
||||
bool pass =
|
||||
ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>(do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
params,
|
||||
split_k,
|
||||
instance_index,
|
||||
list_instances);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
@@ -5,8 +5,8 @@
|
||||
#include <numeric>
|
||||
#include <initializer_list>
|
||||
#include <cstdlib>
|
||||
|
||||
#include "profiler/profile_grouped_conv_fwd_impl.hpp"
|
||||
#include "profiler/profiler_arg_utils.hpp"
|
||||
#include "profiler_operation_registry.hpp"
|
||||
|
||||
namespace {
|
||||
@@ -66,7 +66,10 @@ static void print_helper_msg()
|
||||
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
|
||||
<< "arg7: print tensor value (0: no; 1: yes)\n"
|
||||
<< "arg8: time kernel (0: no, 1: yes)\n"
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
|
||||
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl
|
||||
<< "\nOptional arguments:\n"
|
||||
<< " --instance <id> Run only the specified instance (0-indexed among valid instances)\n"
|
||||
<< " --list-instances List all valid instances without running\n";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
@@ -74,8 +77,17 @@ static void print_helper_msg()
|
||||
|
||||
int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
{
|
||||
// Parse optional named arguments first
|
||||
ck::index_t instance_index = -1;
|
||||
bool list_instances = false;
|
||||
ck::profiler::parse_named_args(argc, argv, instance_index, list_instances);
|
||||
const int named_arg_count = ck::profiler::count_named_args(argc, argv);
|
||||
|
||||
// Adjust argc for positional argument checking
|
||||
const int positional_argc = argc - named_arg_count;
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial
|
||||
if(argc < 10)
|
||||
if(positional_argc < 10)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -91,7 +103,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
const int num_dim_spatial = std::stoi(argv[9]);
|
||||
|
||||
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
|
||||
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
|
||||
if(positional_argc != 9 + 1 + 4 + 6 * num_dim_spatial)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -178,7 +190,14 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
ck::index_t>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
params,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
instance_index,
|
||||
list_instances);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
@@ -194,7 +213,14 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
|
||||
AComputeType,
|
||||
BComputeType,
|
||||
ck::long_index_t>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
do_verification,
|
||||
init_method,
|
||||
do_log,
|
||||
time_kernel,
|
||||
params,
|
||||
ck::tensor_operation::element_wise::PassThrough{},
|
||||
instance_index,
|
||||
list_instances);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user