[rocm-libraries] ROCm/rocm-libraries#4800 (commit 9dcf0cf)

[CK Profiler] Instance selection for grouped conv profilers
 (#4800)

## Motivation

This PR adds instance selection support for ckProfiler grouped
convolution operations (forward, backward data, backward weight),
allowing users to run specific kernel instances rather than sweeping all
available instances.

When profiling or debugging convolution kernels, users often need to
test specific kernel configurations without running the full instance
sweep. This is particularly useful for:
- Debugging a specific failing instance
- Profiling a known-best configuration
- Quick validation during development

## Technical Details

**Features added**:
- `--instance <id>` flag to run only the N-th valid instance (0-indexed)
- `--list-instances` flag to list all valid instances without running
any kernels
- Named arguments can appear anywhere on the command line
- Best instance index is now printed with results for reference
- Python script support via `-ii` / `--instance_index` arguments

**Design decisions**:
- Named arguments (`--instance`, `--list-instances`) instead of
positional to avoid conflicts with existing parameters
- Instance index refers to the N-th valid instance (0-indexed), not the
global instance index
- Auto-disable verification when `--list-instances` is used for fast
enumeration
- Shared utilities in `profiler_arg_utils.hpp` to deduplicate parsing
logic

## Test Plan

Manual testing with various scenarios:

List all valid instances:
```bash
./bin/ckProfiler grouped_conv_fwd <usual args> --list-instances
```

Run only instance 5:
```bash
./bin/ckProfiler grouped_conv_fwd <usual args> --instance 5
```

Test cases:
- Single instance selection
- List instances mode
- Out-of-bounds instance index (verified warning messages)
- No instance flag (runs all instances - default behavior)
- All three operations (fwd, bwd_data, bwd_weight)

## Test Result

All test scenarios passed:
- Instance selection correctly filters kernel executions
- List mode enumerates valid instances without running kernels
- Invalid indices produce appropriate warnings without crashing
- Default behavior (all instances) unchanged when flags not provided
- Consistent behavior across all three grouped convolution operations

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
Johannes Graner
2026-03-03 15:33:20 +00:00
committed by assistant-librarian[bot]
parent fff9f72ffb
commit 1cd031c21d
8 changed files with 361 additions and 89 deletions

View File

@@ -39,7 +39,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
ck::index_t split_k = 1,
index_t instance_index = -1)
index_t instance_index = -1,
bool list_instances = false)
{
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -78,6 +79,10 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
DeviceMem wei_device_buf(sizeof(WeiDataType) * wei_element_space_size);
DeviceMem in_device_buf(sizeof(InDataType) * in_element_space_size);
// Don't create reference if we're only listing instances
if(list_instances)
do_verification = 0;
// Initialize tensors based on do_verification:
// - do_verification=2: GPU-side initialization
// - do_verification=0,1: CPU-side initialization
@@ -198,15 +203,17 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
}
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::index_t best_split_k = 1;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
ck::index_t best_split_k = 1;
ck::index_t best_instance_index = 0;
// profile device op instances
bool pass = true;
index_t num_kernel = 0;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
bool pass = true;
index_t num_kernel = 0;
index_t valid_instances = 0;
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
// workspace_sz will be equal to 0 for other layout than NGCHW
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
DeviceMem workspace_dev(workspace_sz);
@@ -215,11 +222,23 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
// List instances mode - just print and continue
if(list_instances)
{
// skip test if instance_index is specified
std::cout << "[" << (num_kernel - 1) << "] " << op_ptr->GetTypeString()
<< " (SplitK=" << split_k_for_run << ")" << std::endl;
return;
}
// Skip if a specific instance was requested and this isn't it
const bool running_specific_instance = (instance_index != -1);
const bool current_is_target = (num_kernel - 1 == instance_index);
if(running_specific_instance && !current_is_target)
{
return;
}
valid_instances++;
std::string op_name = op_ptr->GetTypeString();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -240,11 +259,12 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_split_k = split_k_for_run;
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_split_k = split_k_for_run;
best_instance_index = num_kernel - 1;
}
// Synchronize before verification to ensure kernel has completed
@@ -257,12 +277,12 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
{
// GPU verification path
using ComputeType_ = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
OutDataType,
WeiDataType>;
OutDataType,
WeiDataType>;
using ComputeType =
std::conditional_t<sizeof(ComputeType_) < sizeof(ComputeDataType),
ComputeType_,
ComputeDataType>;
ComputeType_,
ComputeDataType>;
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
@@ -312,12 +332,12 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
in_device_buf.FromDevice(in_device.mData.data());
using ComputeType_ = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
OutDataType,
WeiDataType>;
OutDataType,
WeiDataType>;
using ComputeType =
std::conditional_t<sizeof(ComputeType_) < sizeof(ComputeDataType),
ComputeType_,
ComputeDataType>;
ComputeType_,
ComputeDataType>;
using AccDataType =
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
const index_t num_accums = conv_param.K_;
@@ -359,7 +379,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
}
}
}
else
else if(list_instances || instance_index == -1)
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
@@ -417,6 +437,11 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
split_k_list = {split_k};
}
if(list_instances)
{
std::cout << "\nValid instances for this problem:" << std::endl;
}
for(auto& op_ptr : op_ptrs)
{
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
@@ -447,9 +472,25 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
}
}
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl;
if(list_instances)
{
std::cout << "\nTotal: " << num_kernel << " valid instances" << std::endl;
return true;
}
printf("\033[36mvalids: %ld\033[0m\n", static_cast<long>(valid_instances));
if(instance_index != -1 && valid_instances == 0)
{
std::cerr << "Error: instance_index " << instance_index
<< " exceeds the number of valid instances (" << num_kernel << ")" << std::endl;
return false;
}
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << " (instance "
<< best_instance_index << ")" << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK "
<< best_split_k << std::endl;
if(instance_index != -1)
{

View File

@@ -44,7 +44,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const std::string& split_k,
index_t instance_index = -1)
index_t instance_index = -1,
bool list_instances = false)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -83,6 +84,10 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight_element_space_size);
DeviceMem out_device_buf(sizeof(OutDataType) * output_element_space_size);
// Don't create reference if we're only listing instances
if(list_instances)
do_verification = 0;
// Initialize tensors based on do_verification:
// - do_verification=2: GPU-side initialization
// - do_verification=0,1: CPU-side initialization
@@ -213,6 +218,8 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
float best_tflops = 0;
float best_gb_per_sec = 0;
std::string best_split_k("1");
index_t best_instance_index = 0;
index_t valid_instances = 0;
// profile device Conv instances
bool all_pass = true;
@@ -257,6 +264,11 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
}
}
if(list_instances)
{
std::cout << "\nValid instances for this problem:" << std::endl;
}
index_t num_kernel = 0;
for(auto& op_ptr : op_ptrs)
{
@@ -316,12 +328,24 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
// List instances mode - just print and continue
if(list_instances)
{
// skip test if instance_index is specified
std::cout << "[" << (num_kernel - 1) << "] " << op_ptr->GetTypeString()
<< " (SplitK=" << split_k_param_str << ")" << std::endl;
continue;
}
// Skip if a specific instance was requested and this isn't it
const bool running_specific_instance = (instance_index != -1);
const bool current_is_target = (num_kernel - 1 == instance_index);
if(running_specific_instance && !current_is_target)
{
continue;
}
valid_instances++;
std::string op_name = op_ptr->GetTypeString();
auto invoker_ptr = op_ptr->MakeInvokerPointer();
@@ -341,11 +365,12 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_split_k = split_k_param_str;
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_split_k = split_k_param_str;
best_instance_index = num_kernel - 1;
}
// Synchronize before verification to ensure kernel has completed
@@ -491,7 +516,7 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
}
}
}
else
else if(list_instances || instance_index == -1)
{
std::cout << op_ptr->GetTypeString() << " does not support this problem"
<< std::endl;
@@ -499,11 +524,25 @@ bool profile_grouped_conv_bwd_weight_impl(int do_verification,
}
}
printf("\033[36mvalids: %d\033[0m\n", num_kernel);
if(list_instances)
{
std::cout << "\nTotal: " << num_kernel << " valid instances" << std::endl;
return true;
}
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << ", SplitK " << best_split_k << std::endl;
printf("\033[36mvalids: %ld\033[0m\n", static_cast<long>(valid_instances));
if(instance_index != -1 && valid_instances == 0)
{
std::cerr << "Error: instance_index " << instance_index
<< " exceeds the number of valid instances (" << num_kernel << ")" << std::endl;
return false;
}
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << " (instance "
<< best_instance_index << ")" << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK "
<< best_split_k << std::endl;
if(instance_index != -1)
{

View File

@@ -47,7 +47,8 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
bool time_kernel,
const ck::utils::conv::ConvParam& conv_param,
const OutElementOp out_element_op = OutElementOp{},
index_t instance_index = -1)
index_t instance_index = -1,
bool list_instances = false)
{
using InElementOp = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
@@ -108,6 +109,10 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
DeviceMem wei_device_buf(sizeof(WeiDataType) * weight_size);
DeviceMem out_device_buf(sizeof(OutDataType) * output_size);
// Don't create reference if we're only listing instances
if(list_instances)
do_verification = 0;
// Initialize tensors based on do_verification:
// - do_verification=2: GPU-side initialization
// - do_verification=0,1: CPU-side initialization
@@ -210,11 +215,12 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
}
std::string best_op_name;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
index_t num_kernel = 0;
int valids = 0;
float best_avg_time = 0;
float best_tflops = 0;
float best_gb_per_sec = 0;
index_t num_kernel = 0;
index_t valid_instances = 0;
index_t best_instance_index = 0;
// profile device op instances
bool pass = true;
@@ -228,14 +234,25 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
if(op_ptr->IsSupportedArgument(argument_ptr.get()))
{
num_kernel++;
if((instance_index != -1) && (instance_index + 1 != num_kernel))
// List instances mode - just print and continue
if(list_instances)
{
std::cout << "[" << (num_kernel - 1) << "] " << op_ptr->GetTypeString()
<< std::endl;
return;
}
// Skip if a specific instance was requested and this isn't it
const bool running_specific_instance = (instance_index != -1);
const bool current_is_target = (num_kernel - 1 == instance_index);
if(running_specific_instance && !current_is_target)
{
// skip test if instance_index is specified
return;
}
std::string op_name = op_ptr->GetTypeString();
valids++;
valid_instances++;
out_device_buf.SetZero();
@@ -256,10 +273,11 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
if(tflops > best_tflops)
{
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_op_name = op_name;
best_tflops = tflops;
best_avg_time = avg_time;
best_gb_per_sec = gb_per_sec;
best_instance_index = num_kernel - 1;
}
// Synchronize before verification to ensure kernel has completed
@@ -330,7 +348,7 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
}
}
}
else
else if(list_instances || instance_index == -1)
{
std::cout << op_ptr->GetTypeString() << " does not support this problem" << std::endl;
}
@@ -357,6 +375,11 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
std::cout << "ckProfiler found " << op_ptrs.size() << " instances" << std::endl;
if(list_instances)
{
std::cout << "\nValid instances for this problem:" << std::endl;
}
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
@@ -382,11 +405,24 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
run_impl(op_ptr, argument_ptr);
}
printf("\033[36mvalids: %d\033[0m\n", valids);
if(list_instances)
{
std::cout << "\nTotal: " << num_kernel << " valid instances" << std::endl;
return true;
}
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name
<< "\navg_time: " << best_avg_time << "\ntflops: " << best_tflops
<< "\nGB/s: " << best_gb_per_sec << std::endl;
printf("\033[36mvalids: %ld\033[0m\n", static_cast<long>(valid_instances));
if(instance_index != -1 && valid_instances == 0)
{
std::cerr << "Error: instance_index " << instance_index
<< " exceeds the number of valid instances (" << num_kernel << ")" << std::endl;
return false;
}
std::cout << "Best configuration parameters:" << "\nname: " << best_op_name << " (instance "
<< best_instance_index << ")" << "\navg_time: " << best_avg_time
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
if(instance_index != -1)
{
std::cout << "grouped_conv_fwd_instance (" << instance_index << "/" << num_kernel

View File

@@ -0,0 +1,54 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <cstring>
#include <string>
#include "ck/ck.hpp"
namespace ck {
namespace profiler {
// Parse optional named arguments from argv
// Modifies instance_index and list_instances based on found arguments
inline void
parse_named_args(int argc, char* argv[], ck::index_t& instance_index, bool& list_instances)
{
instance_index = -1;
list_instances = false;
for(int i = 1; i < argc; ++i)
{
if(std::strcmp(argv[i], "--instance") == 0 && i + 1 < argc)
{
instance_index = std::stoi(argv[i + 1]);
}
else if(std::strcmp(argv[i], "--list-instances") == 0)
{
list_instances = true;
}
}
}
// Count named arguments to adjust argc for positional arg checking
inline int count_named_args(int argc, char* argv[])
{
int count = 0;
for(int i = 1; i < argc; ++i)
{
if(std::strcmp(argv[i], "--instance") == 0)
{
count += 2;
++i; // skip the value
}
else if(std::strcmp(argv[i], "--list-instances") == 0)
{
count += 1;
}
}
return count;
}
} // namespace profiler
} // namespace ck

View File

@@ -5,8 +5,8 @@
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
#include "profiler/profiler_arg_utils.hpp"
#include "profiler_operation_registry.hpp"
namespace {
@@ -48,7 +48,10 @@ static void print_helper_msg()
<< "arg6: print tensor value (0: no; 1: yes)\n"
<< "arg7: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl
<< "arg8: split-K (0: internally computed split-K value; 1, 2, 4, 8, 16, 32, 64, 128: set k batches explicitly)\n";
<< "arg8: split-K (0: internally computed split-K value; 1, 2, 4, 8, 16, 32, 64, 128: set k batches explicitly)\n"
<< "\nOptional arguments:\n"
<< " --instance <id> Run only the specified instance (0-indexed among valid instances)\n"
<< " --list-instances List all valid instances without running\n";
// clang-format on
}
@@ -56,8 +59,17 @@ static void print_helper_msg()
int profile_grouped_conv_bwd_data(int argc, char* argv[])
{
// Parse optional named arguments first
ck::index_t instance_index = -1;
bool list_instances = false;
ck::profiler::parse_named_args(argc, argv, instance_index, list_instances);
const int named_arg_count = ck::profiler::count_named_args(argc, argv);
// Adjust argc for positional argument checking
const int positional_argc = argc - named_arg_count;
// 8 for control, 1 for num_dim_spatial
if(argc < 9)
if(positional_argc < 9)
{
print_helper_msg();
return 1;
@@ -72,7 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
const int num_dim_spatial = std::stoi(argv[8]);
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
if(positional_argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
{
print_helper_msg();
return 1;
@@ -111,15 +123,22 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
using InDataType = decltype(in_type);
using ComputeDataType = decltype(compute_type);
bool pass = ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
OutLayout,
WeiLayout,
InLayout,
OutDataType,
WeiDataType,
InDataType,
ComputeDataType>(
do_verification, init_method, do_log, time_kernel, params, split_k);
bool pass =
ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
OutLayout,
WeiLayout,
InLayout,
OutDataType,
WeiDataType,
InDataType,
ComputeDataType>(do_verification,
init_method,
do_log,
time_kernel,
params,
split_k,
instance_index,
list_instances);
return pass ? 0 : 1;
};

View File

@@ -7,6 +7,7 @@
#include <numeric>
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
#include "profiler/profiler_arg_utils.hpp"
#include "profiler_operation_registry.hpp"
namespace {
@@ -61,6 +62,10 @@ static void print_helper_msg()
<< ck::utils::conv::get_conv_param_parser_helper_msg()
<< " SplitK (-1 for internally computed split-K value, positive value to set k "
"batches explicitly, or 'all' to test all internal split-K values)\n"
<< "\nOptional arguments:\n"
<< " --instance <id> Run only the specified instance (0-indexed among valid "
"instances)\n"
<< " --list-instances List all valid instances without running\n"
<< std::endl;
}
@@ -68,8 +73,17 @@ static void print_helper_msg()
int profile_grouped_conv_bwd_weight(int argc, char* argv[])
{
// Parse optional named arguments first
ck::index_t instance_index = -1;
bool list_instances = false;
ck::profiler::parse_named_args(argc, argv, instance_index, list_instances);
const int named_arg_count = ck::profiler::count_named_args(argc, argv);
// Adjust argc for positional argument checking
const int positional_argc = argc - named_arg_count;
// 8 for control, 1 for num_dim_spatial
if(argc < 9)
if(positional_argc < 9)
{
print_helper_msg();
return 1;
@@ -84,7 +98,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
const int num_dim_spatial = std::stoi(argv[8]);
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
if(positional_argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
{
print_helper_msg();
return 1;
@@ -129,16 +143,23 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using ComputeTypeA = decltype(compute_type_a);
using ComputeTypeB = decltype(compute_type_b);
bool pass = ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ComputeTypeA,
ComputeTypeB>(
do_verification, init_method, do_log, time_kernel, params, split_k);
bool pass =
ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial,
InLayout,
WeiLayout,
OutLayout,
InDataType,
WeiDataType,
OutDataType,
ComputeTypeA,
ComputeTypeB>(do_verification,
init_method,
do_log,
time_kernel,
params,
split_k,
instance_index,
list_instances);
return pass ? 0 : 1;
};

View File

@@ -5,8 +5,8 @@
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include "profiler/profile_grouped_conv_fwd_impl.hpp"
#include "profiler/profiler_arg_utils.hpp"
#include "profiler_operation_registry.hpp"
namespace {
@@ -66,7 +66,10 @@ static void print_helper_msg()
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
<< "arg8: time kernel (0: no, 1: yes)\n"
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl;
<< ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl
<< "\nOptional arguments:\n"
<< " --instance <id> Run only the specified instance (0-indexed among valid instances)\n"
<< " --list-instances List all valid instances without running\n";
// clang-format on
}
@@ -74,8 +77,17 @@ static void print_helper_msg()
int profile_grouped_conv_fwd(int argc, char* argv[])
{
// Parse optional named arguments first
ck::index_t instance_index = -1;
bool list_instances = false;
ck::profiler::parse_named_args(argc, argv, instance_index, list_instances);
const int named_arg_count = ck::profiler::count_named_args(argc, argv);
// Adjust argc for positional argument checking
const int positional_argc = argc - named_arg_count;
// 8 for control, 1 for num_dim_spatial
if(argc < 10)
if(positional_argc < 10)
{
print_helper_msg();
return 1;
@@ -91,7 +103,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
const int num_dim_spatial = std::stoi(argv[9]);
// 9 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
if(argc != 9 + 1 + 4 + 6 * num_dim_spatial)
if(positional_argc != 9 + 1 + 4 + 6 * num_dim_spatial)
{
print_helper_msg();
return 1;
@@ -178,7 +190,14 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
AComputeType,
BComputeType,
ck::index_t>(
do_verification, init_method, do_log, time_kernel, params);
do_verification,
init_method,
do_log,
time_kernel,
params,
ck::tensor_operation::element_wise::PassThrough{},
instance_index,
list_instances);
return pass ? 0 : 1;
}
@@ -194,7 +213,14 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
AComputeType,
BComputeType,
ck::long_index_t>(
do_verification, init_method, do_log, time_kernel, params);
do_verification,
init_method,
do_log,
time_kernel,
params,
ck::tensor_operation::element_wise::PassThrough{},
instance_index,
list_instances);
return pass ? 0 : 1;
}

View File

@@ -128,6 +128,12 @@ def run_ck_grouped_conv_fwd(args):
cmd += [str(args.in_channels)]
add_conv_params_to_cmd(args, cmd)
# Add optional named arguments
if args.instance != -1:
cmd += ["--instance", str(args.instance)]
if args.list_instances:
cmd += ["--list-instances"]
run_ck_profiler_cmd(cmd)
@@ -148,6 +154,13 @@ def run_ck_grouped_conv_bwd_data(args):
add_conv_params_to_cmd(args, cmd)
cmd += [str(args.split_k_value)]
# Add optional named arguments
if args.instance != -1:
cmd += ["--instance", str(args.instance)]
if args.list_instances:
cmd += ["--list-instances"]
run_ck_profiler_cmd(cmd)
@@ -168,6 +181,13 @@ def run_ck_grouped_conv_bwd_weight(args):
add_conv_params_to_cmd(args, cmd)
cmd += [str(args.split_k_value)]
# Add optional named arguments
if args.instance != -1:
cmd += ["--instance", str(args.instance)]
if args.list_instances:
cmd += ["--list-instances"]
run_ck_profiler_cmd(cmd)
@@ -461,6 +481,22 @@ if __name__ == "__main__":
required=False,
help="Number of Groups (Default=1)",
)
parser.add_argument(
"-instance",
"--instance",
type=int,
default=-1,
required=False,
help="Instance index (Default=-1)",
)
parser.add_argument(
"-list-instances",
"--list-instances",
action="store_true",
default=False,
required=False,
help="List valid instances without running",
)
args, unknown = parser.parse_known_args()
init_const_args(args)