[rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)

[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv operator. This PR adds support for fwd, bwd-data, and bwd-weight grouped-conv kernels. A tile_engine utility has also been added to compile and run any selected kernel configuration through the dispatcher infrastructure.

## Technical Details

1. A tile_engine utility is added to benchmark each shape with all possible kernel + tile_size combinations:
   https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv are added to the models directory, with three separate models for fwd, bwd-data, and bwd-weight (see the selection sketch after this list):
   https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); see the sketch after this list.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory access faults
   - **Solution**: Mirror the FMHA pattern and defer GPU initialization until the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added an _ensure_initialized() method for lazy GPU loading
     - GPU context created only on the first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to expand kernel selection
   - Added automatic padding support for unsupported shapes in the dispatcher runner
   - Created a single source of truth between tile_engine and dispatcher for the architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic selections
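
The heuristic itself is a per-variant regression model that scores candidate kernel configurations for a given problem shape. As a rough illustration only (this is not the dispatcher's actual API; the model path, feature layout, and candidate-config fields below are assumptions):

```python
# Minimal sketch of LGBM-based kernel selection, assuming each model is a
# standard LightGBM regressor scoring (problem shape, candidate config) pairs.
# File name, feature order, and dict keys are illustrative only.
import lightgbm as lgb
import numpy as np

model = lgb.Booster(model_file="heuristics/models/grouped_conv_fwd.txt")  # hypothetical path

def pick_config(problem: dict, candidates: list[dict]) -> dict:
    """Return the candidate config the model predicts to be fastest."""
    feats = np.array(
        [
            [problem["N"], problem["C"], problem["K"], problem["H"], problem["W"],
             problem["G"], c["tile_m"], c["tile_n"], c["tile_k"]]
            for c in candidates
        ],
        dtype=np.float32,
    )
    scores = model.predict(feats)  # e.g. predicted runtime for each candidate
    return candidates[int(np.argmin(scores))]  # argmin if the target is runtime
```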
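
The lazy GPU initialization described in item 3 amounts to recording the compiled library path at construction time and deferring ctypes.CDLL / GPU context creation to the first run(). A minimal sketch of the pattern (the run_grouped_conv entry point and argument handling are hypothetical):

```python
# Sketch of lazy GPU initialization: the runner stores only the library path,
# so ProcessPoolExecutor can fork workers without inheriting a GPU context;
# the shared library is loaded once, in the process that actually runs on GPU.
import ctypes
from pathlib import Path

class GpuGroupedConvRunner:
    def __init__(self, lib_path: Path):
        self._lib_path = lib_path  # no ctypes.CDLL here
        self._lib = None

    def _ensure_initialized(self):
        # Load the compiled kernel library lazily, exactly once.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, args):
        self._ensure_initialized()  # GPU context created on first call
        return self._lib.run_grouped_conv(args)  # hypothetical C entry point
```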

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and unseen data sets with up to 300 problems.
2. Added test cases in both dispatcher and tile_engine to validate the heuristic.

## Test Results
Results on unseen shapes, validated on gfx950. Efficiency is measured against the per-problem oracle best; see the sketch below the per-model results.
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
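
For context, "efficiency vs. oracle" above is the per-problem ratio of the oracle-best kernel's time to the time of the ML-selected kernel, summarized over the validation set. A minimal sketch of how such statistics could be computed (field names are assumptions, not the validation script's actual schema):

```python
# Per-problem efficiency = oracle best time / time of the ML-selected kernel,
# then mean / median / 10th percentile over the validation set.
import numpy as np

def summarize(results: list[dict]) -> dict:
    eff = np.array([r["oracle_time_us"] / r["heuristic_time_us"] for r in results])
    return {
        "mean_pct": eff.mean() * 100.0,
        "median_pct": np.median(eff) * 100.0,
        "p10_pct": np.percentile(eff, 10) * 100.0,  # 10% of problems fall below this
    }
```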

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Commit 6989cf800c (parent b05040b919), authored by Yaswanth Raparti on 2026-05-08 20:48:42 +00:00, committed by assistant-librarian[bot].
65 changed files with 13206 additions and 389 deletions.


@@ -41,6 +41,26 @@ except ImportError:
    ArchFilter = None
    OperatorType = None

# Import tile configurations from grouped_config_rules (single source of truth)
try:
    from grouped_config_rules import (
        COMMON_TILES,
        TILE_TO_WAVE,
        TILE_TO_WARP,
        VARIANT_PIPELINES,
        BWD_WEIGHT_TILES,
        COMPV4_COMPATIBLE_TILES,
    )

    HAS_TILE_CONFIGS = True
except ImportError:
    HAS_TILE_CONFIGS = False
    COMMON_TILES = []
    TILE_TO_WAVE = {}
    TILE_TO_WARP = {}
    VARIANT_PIPELINES = {}
    BWD_WEIGHT_TILES = []
    COMPV4_COMPATIBLE_TILES = []

# ============================================================================
# Configuration and Data Structures
@@ -494,6 +514,21 @@ struct {kernel_name}_Config {{
    # Create valid C++ namespace name
    ns_name = "ns_" + kernel_name.replace("-", "_")

    # basic_v1 / basic_async_v1 inherit BaseGemmPipelineAGmemBGmemCRegV1
    # whose TailHandler takes (run_func, has_hot_loop) and invokes
    # run_func(bool_constant<...>) -- 1 lambda arg. Other pipelines pass
    # (run_func, has_hot_loop, tail_number) and invoke 2-arg run_func.
    if tr.pipeline in ("basic_v1", "basic_async_v1"):
        tail_handler_call = "BaseGemmPipeline::TailHandler(Run, has_hot_loop);"
        run_lambda_signature = "[&](const auto has_hot_loop_)"
    else:
        tail_handler_call = (
            "BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);"
        )
        run_lambda_signature = (
            "[&](const auto has_hot_loop_, const auto tail_number_)"
        )

    return f"""
// Unique namespace for this kernel to avoid conflicts when including multiple kernels
namespace {ns_name} {{
@@ -605,7 +640,7 @@ struct {kernel_name}_Launcher {{
using Kernel = {kernel_type}<
GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
const auto Run = {run_lambda_signature} {{
auto kargs = Kernel::MakeKernelArgs(args);
if (!Kernel::IsSupportedArgument(kargs)) {{
@@ -621,7 +656,7 @@ struct {kernel_name}_Launcher {{
return ave_time;
}};
BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
{tail_handler_call}
return ave_time;
}}
}};
@@ -1021,7 +1056,10 @@ def get_default_configs(
    variants: Optional[List[GroupedConvVariant]] = None,
    ndims: Optional[List[int]] = None,
) -> List[GroupedConvKernelConfig]:
    """Get default grouped convolution configurations for target architecture"""
    """Get default grouped convolution configurations for target architecture.

    Uses tile configurations from grouped_conv_instance_builder.py as single source of truth.
    """
    configs = []

    if variants is None:
@@ -1029,39 +1067,53 @@ def get_default_configs(
    if ndims is None:
        ndims = [2]

    # Valid configurations per variant (based on CK Tile example configs)
    # Forward and Backward Data: standard GEMM-like tiles
    fwd_bwd_data_tiles = [
        # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
        (128, 128, 32, 2, 2, 32, 32, 16),  # Standard 128x128
        (256, 256, 32, 2, 2, 32, 32, 16),  # Large 256x256
        (64, 64, 32, 1, 4, 16, 16, 16),  # Small 64x64
        (128, 64, 32, 2, 2, 32, 32, 16),  # Rectangular
        (16, 64, 64, 1, 4, 16, 16, 32),  # Tall and narrow
    ]
    # Import tile configs from instance builder (single source of truth)
    if not HAS_TILE_CONFIGS or not COMMON_TILES:
        log.warning("grouped_config_rules not available, using fallback tile configs")
        # Fallback to minimal set if grouped_config_rules unavailable
        fwd_bwd_data_tiles = [
            (128, 128, 32, 2, 2, 32, 32, 16),
            (64, 64, 32, 1, 4, 16, 16, 16),
            (16, 64, 64, 1, 4, 16, 16, 32),
        ]
        bwd_weight_tiles = [(16, 64, 64, 1, 4, 16, 16, 32)]
    else:
        # Build tile list from COMMON_TILES with wave/warp mappings
        fwd_bwd_data_tiles = []
        for tile_m, tile_n, tile_k in COMMON_TILES:
            tile_key = (tile_m, tile_n, tile_k)
            if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
                wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
                warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
                fwd_bwd_data_tiles.append(
                    (tile_m, tile_n, tile_k, wave_m, wave_n, warp_m, warp_n, warp_k)
                )

    # Backward Weight: VERY specific tile configs that work with CK Tile's bwd_weight kernel
    # Based on ConvConfigComputeV3 from CK Tile examples (example/ck_tile/20_grouped_convolution/)
    # Note: Backward weight has strict constraints on warp configurations due to transpose_tile2d
    # Only specific warp configs work: (1, 4, 1) and (4, 1, 1) are known to work
    bwd_weight_tiles = [
        # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
        # ConvConfigComputeV3: The primary working config for backward weight
        (16, 64, 64, 1, 4, 16, 16, 32),
    ]
        # Backward weight: use BWD_WEIGHT_TILES from config rules
        bwd_weight_tiles = []
        for tile_m, tile_n, tile_k in BWD_WEIGHT_TILES:
            tile_key = (tile_m, tile_n, tile_k)
            if tile_key in TILE_TO_WAVE and tile_key in TILE_TO_WARP:
                wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
                warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
                bwd_weight_tiles.append(
                    (tile_m, tile_n, tile_k, wave_m, wave_n, warp_m, warp_n, warp_k)
                )

    for variant in variants:
        # Select tile configs based on variant
        if variant == GroupedConvVariant.BACKWARD_WEIGHT:
            tile_configs = bwd_weight_tiles
            # Backward weight ONLY supports compv3 (compv4/compv5 have transpose_tile2d issues)
            pipelines = [("compv3", "cshuffle")]
            # Backward weight supports compv3 and mem pipelines
            # (compv4/compv5 have transpose_tile2d issues)
            pipelines = [("compv3", "cshuffle"), ("mem", "default")]
            # Also generate two-stage variants (fp32 workspace + elementwise convert)
            two_stage_flags = [False, True]
        elif variant == GroupedConvVariant.BACKWARD_DATA:
            tile_configs = fwd_bwd_data_tiles
            # Backward data ONLY supports compv3 (compv4 has get_length issues in bwd_data kernel)
            pipelines = [("compv3", "cshuffle")]
            # Backward data supports compv3 and mem pipelines
            # (compv4/compv5 have get_length issues in bwd_data kernel)
            pipelines = [("compv3", "cshuffle"), ("mem", "default")]
            two_stage_flags = [False]
        else:
            tile_configs = fwd_bwd_data_tiles
@@ -1080,6 +1132,12 @@ def get_default_configs(
                warp_tile_n,
                warp_tile_k,
            ) in tile_configs:
                # Skip tiles incompatible with compv4
                if pipeline == "compv4" and HAS_TILE_CONFIGS:
                    tile_key = (tile_m, tile_n, tile_k)
                    if tile_key not in COMPV4_COMPATIBLE_TILES:
                        continue  # Skip this tile for compv4

                for two_stage in two_stage_flags:
                    adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k
@@ -1609,7 +1667,16 @@ def main():
    parser.add_argument(
        "--pipeline",
        type=str,
        choices=["mem", "compv3", "compv4", "compv5"],
        choices=[
            "basic_v1",
            "basic_async_v1",
            "mem",
            "compv3",
            "compv4",
            "compv5",
            "compv6",
            "comp_async",
        ],
        help="Pipeline type",
    )
    parser.add_argument(
@@ -1642,6 +1709,16 @@ def main():
        default=None,
        help="Double SMEM buffer (true/false)",
    )
    parser.add_argument(
        "--split-image",
        action="store_true",
        help="Enable split-image (EnableSplitImage) for large spatial tensors",
    )
    parser.add_argument(
        "--two-stage",
        action="store_true",
        help="Enable two-stage bwd_weight (fp32 workspace + elementwise convert)",
    )

    args = parser.parse_args()
@@ -1679,7 +1756,13 @@ def main():
    if args.double_smem_buffer is not None:
        dsb = args.double_smem_buffer.lower() == "true"
    else:
        dsb = pipeline == "compv4"  # compv4 requires double buffer
        # Historical default: only compv4 auto-defaults to dsb=true.
        # Other pipelines that also require DoubleSmemBuffer (e.g. comp_async)
        # must be told explicitly via --double-smem-buffer true; otherwise
        # they will fail loudly at the pipeline header static_assert. This
        # is intentional -- silent fallback to a different config would
        # mask the user's input.
        dsb = pipeline == "compv4"

    trait = GroupedConvTraitConfig(
        pipeline=pipeline,
@@ -1690,6 +1773,8 @@ def main():
        pad_k=args.pad_k,
        double_smem_buffer=dsb,
        num_groups_to_merge=args.num_groups_to_merge,
        split_image=args.split_image,
        two_stage=args.two_stage,
    )

    config = GroupedConvKernelConfig(
        tile=tile,
@@ -1719,18 +1804,20 @@ def main():
print(f" Spatial dims: {args.ndim}")
print(f"\nConfigurations ({len(filtered_configs)}):")
for cfg in filtered_configs:
print(f" - {cfg.name('fp16')}")
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
print(
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
)
print(
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
)
print(
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
)
# List configs for each requested datatype (fixes bf16 -> fp16 bug)
for dt in args.datatype:
print(f" - {cfg.name(dt)}")
print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
print(
f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
)
print(
f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
)
print(
f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
)
return
# Generate