[CK][CK Tile] Improvements for grouped conv fwd tile profiling (#5114)

## Motivation

Improve profiling for grouped convolution forward for better comparison
between CK and CK Tile
## Technical Details

- Include preprocessing time for ck tile
- Add flush cache for conv fwd profiler
- Switch configs to builder reflect
- Add KPerXdl deduce
- Add non-grouped ported instances

## Test Plan

test_grouped_convnd_fwd_tile

## Test Result

pass

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

AICK-786
This commit is contained in:
Bartłomiej Kocot
2026-03-11 23:38:15 +01:00
committed by GitHub
parent 622122155a
commit 1972d39410
24 changed files with 2375 additions and 1874 deletions

View File

@@ -9,6 +9,7 @@
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/host.hpp"
#include <type_traits>
#include <array>
@@ -59,7 +60,8 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
std::end(kargs.wei_g_k_c_xs_lengths.data),
1,
std::multiplies<std::size_t>());
auto preprocess = [&]() {
auto preprocess = [&]() {
if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
if(args.k_batch > 1)
@@ -73,10 +75,20 @@ template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename Ou
constexpr index_t minimum_occupancy =
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
return RunResult::from_runtime(ck_tile::launch_kernel_time_mask(
s_conf,
preprocess,
ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
if(s_conf.flush_cache_)
{
return RunResult::from_runtime(ck_tile::launch_kernel_time_mask_flush_cache(
s_conf,
preprocess,
ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
}
else
{
return RunResult::from_runtime(ck_tile::launch_kernel_time_mask(
s_conf,
preprocess,
ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
}
}
} // namespace detail