[rocm-libraries] ROCm/rocm-libraries#5114 (commit 59b8cb5)

[CK][CK Tile] Improvements for grouped conv fwd tile
 profiling (#5114)

## Motivation

Improve profiling for grouped convolution forward for better comparison
between CK and CK Tile
## Technical Details

- Include preprocessing time for ck tile
- Add flush cache for conv fwd profiler
- Switch configs to builder reflect
- Add KPerXdl deduce
- Add non-grouped ported instances

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

pass

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-786
This commit is contained in:
Bartłomiej Kocot
2026-03-11 22:39:20 +00:00
committed by assistant-librarian[bot]
parent c1f2d8166d
commit 2169367735
24 changed files with 2375 additions and 1874 deletions

View File

@@ -65,7 +65,9 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
const ckt::Outputs<SIGNATURE>& outputs,
const ck_tile::stream_config& s_conf)
{
float best_avg_time = std::numeric_limits<float>::max();
// Run first instance as dummy to get proper time from the first instance
bool dummy_run_executed = false;
float best_avg_time = std::numeric_limits<float>::max();
std::string best_op_name, op_name;
bool is_supported;
float avg_time;
@@ -84,6 +86,12 @@ run_grouped_conv_forward_tile_algs(const ckt::Args<SIGNATURE>& args,
auto ref_conv = ReferenceInstance{};
auto ref_result = ckt::run(ref_conv, args, inputs, reference.get());
auto run_alg = [&](auto&& run_alg_func) {
if(!dummy_run_executed)
{
// Run first instance twice
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
dummy_run_executed = true;
}
std::tie(is_supported, avg_time, op_name) = run_alg_func(args, inputs, outputs, s_conf);
if(is_supported)
{

View File

@@ -65,7 +65,19 @@ void print_instances()
for(const auto& op_ptr : op_ptrs)
{
#ifdef CK_EXPERIMENTAL_BUILDER
const auto& instance_str = op_ptr->GetInstanceString();
if(!instance_str.empty())
{
std::cout << instance_str << std::endl;
}
else
{
std::cout << op_ptr->GetTypeString() << std::endl;
}
#else
std::cout << op_ptr->GetTypeString() << std::endl;
#endif
}
}
} // namespace fwd
@@ -298,8 +310,13 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
auto invoker_ptr = op_ptr->MakeInvokerPointer();
float avg_time =
invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel});
float avg_time = invoker_ptr->Run(argument_ptr.get(),
StreamConfig{nullptr,
time_kernel,
0 /*log_level*/,
5 /*cold_iters*/,
50 /*nrepeat_*/,
time_kernel /*flush_cache*/});
std::size_t flop = conv_param.GetFlops();
std::size_t num_btype = conv_param.GetByte<InDataType, WeiDataType, OutDataType>();
@@ -420,6 +437,30 @@ bool profile_grouped_conv_fwd_impl(int do_verification,
std::cout << "\nValid instances for this problem:" << std::endl;
}
// Run first instance twice to get proper time
{
auto argument_ptr = op_ptrs[0]->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),
wei_device_buf.GetDeviceBuffer(),
{},
out_device_buf.GetDeviceBuffer(),
a_g_n_c_wis_lengths,
a_g_n_c_wis_strides,
b_g_k_c_xs_lengths,
b_g_k_c_xs_strides,
{},
{},
e_g_n_k_wos_lengths,
e_g_n_k_wos_strides,
conv_filter_strides,
conv_filter_dilations,
input_left_pads,
input_right_pads,
in_element_op,
wei_element_op,
out_element_op);
run_impl(op_ptrs[0], argument_ptr);
}
for(auto& op_ptr : op_ptrs)
{
auto argument_ptr = op_ptr->MakeArgumentPointer(in_device_buf.GetDeviceBuffer(),

View File

@@ -97,7 +97,16 @@ int call_profiler(const ckt::Args<SIGNATURE>& args, bool time_kernel)
std::string op_name;
bool valid;
std::tie(valid, avg_time, op_name) = ckp::run_grouped_conv_forward_tile_algs(
args, inputs.get(), outputs.get(), ck_tile::stream_config{nullptr, time_kernel, 0, 5, 50});
args,
inputs.get(),
outputs.get(),
ck_tile::stream_config{nullptr,
time_kernel,
0 /*log_level*/,
5 /*cold_iters*/,
50 /*nrepeat_*/,
true /*is_gpu_timer_*/,
time_kernel /*flush_cache*/});
if(time_kernel)
{
std::cout << "Best configuration parameters:" << "\nname: " << op_name