composable_kernel/dispatcher/examples/grouped_conv/python/11_test_schedulers.py
Yaswanth Raparti 6989cf800c [rocm-libraries] ROCm/rocm-libraries#6327 (commit 1e7a12e)
[CK][CK TILE] Dispatcher kernel selection heuristic for grouped conv (#6327)

## Motivation
The ML heuristic in the dispatcher does not yet support the grouped-conv
operator. This PR adds support for fwd, bwd-data, and bwd-weight
grouped-conv kernels. A tile_engine utility has also been added to compile
and run any selected kernel configuration through the dispatcher
infrastructure.

## Technical Details

1. A tile_engine utility has been added to benchmark each shape with all
possible kernel + tile_size combinations:
https://github.com/ROCm/rocm-libraries/blob/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/tile_engine/ops/grouped_conv/grouped_conv_full_benchmark.py
2. New LGBM regressor models for grouped conv have been added to the models
directory; there are three separate models for fwd, bwd-data, and bwd-weight:
https://github.com/ROCm/rocm-libraries/tree/users/yraparti/ck/dispatcher-grouped-conv-heuristics/projects/composablekernel/dispatcher/heuristics/models
3. Implemented lazy GPU initialization (dispatcher/python); see the sketch
after this list.
   - **Issue**: ProcessPoolExecutor fork() + GPU context caused memory
access faults
   - **Solution**: Mirror the FMHA pattern: defer GPU initialization until
the first run()
   - **Changes**:
     - setup_multiple_grouped_conv_dispatchers() returns List[Path], not
loaded libs
     - GpuGroupedConvRunner.__init__() no longer calls ctypes.CDLL
     - Added _ensure_initialized() method for lazy GPU loading
     - GPU context created only on first run() call
   - **Benefit**: Parallel compilation now works without GPU conflicts
4. Addressed a few miscellaneous issues:
   - Fixed a BF16->FP16 naming bug in the dispatcher wrapper
   - Added new tile sizes and the comp_v5 pipeline to the arch spec to
expand kernel selection
   - Added automatic padding support for unsupported shapes in the
dispatcher runner
   - Created a single source of truth between tile_engine and the
dispatcher for architecture and tile_size details
   - Built a validation script to compare oracle_best vs. ml_heuristic
selections
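
To make item 3 concrete, here is a minimal sketch of the lazy-initialization pattern (the names `GpuGroupedConvRunner`, `_ensure_initialized()`, `run()`, and the `ctypes.CDLL` call come from the description above; the constructor signature and dispatch details are assumptions, not the exact implementation):

```python
import ctypes
from pathlib import Path


class GpuGroupedConvRunner:
    """Sketch: defer GPU library loading until the first run() (FMHA pattern)."""

    def __init__(self, lib_path: Path):
        # Store only the path; no ctypes.CDLL here, so fork()-based
        # ProcessPoolExecutor workers never inherit a live GPU context.
        self._lib_path = lib_path
        self._lib = None

    def _ensure_initialized(self):
        # Load the compiled kernel library (creating the GPU context)
        # on first use only.
        if self._lib is None:
            self._lib = ctypes.CDLL(str(self._lib_path))

    def run(self, *args, **kwargs):
        # The GPU is touched for the first time here, after any fork()
        # for parallel compilation has already happened.
        self._ensure_initialized()
        # Dispatch into self._lib here (kernel-call details omitted).
```

Because `setup_multiple_grouped_conv_dispatchers()` now returns only `List[Path]`, the parent process can compile kernels in parallel workers and construct runners without creating a GPU context until a kernel is actually launched.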

## Test Plan

1. Validated fwd, bwd-data, and bwd-weight kernels on both known and
unseen data sets containing up to 300 problems.
2. Ensured that test cases are added in both dispatcher and tile_engine
to validate the heuristic.

## Test Result
Results on unseen shapes, validated on gfx950:
#### Forward Pass Model
- **Training Data**: 48,845 measurements across 1,372 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.05%**
  - Median Efficiency: **96.8%**
  - P10 Efficiency: **79.9%**

#### Backward Data Gradient (bwd_data) Model
- **Training Data**: 18,773 measurements across 891 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **93.8%**
  - Median Efficiency: **96.5%**
  - P10 Efficiency: **82.9%**

#### Backward Weight Gradient (bwd_weight) Model
- **Training Data**: 34,900 measurements across 1,508 unique problem
shapes
- **Validation Set**: 300 unseen problems from model crawler
- **Validation Performance** (vs. oracle):
  - Mean Efficiency: **96.1%**
  - Median Efficiency: **99.2%**
  - P10 Efficiency: **89.4%**
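
For reference, a minimal sketch of how the efficiency statistics above can be computed, assuming per-problem efficiency is the oracle-best kernel time divided by the time of the ML-selected kernel (100% means the heuristic picked the fastest kernel); `oracle_times_ms` and `ml_times_ms` are hypothetical inputs from the validation script:

```python
import numpy as np


def efficiency_summary(oracle_times_ms, ml_times_ms):
    """Per-problem efficiency: oracle-best time / ML-pick time, in percent."""
    eff = 100.0 * np.asarray(oracle_times_ms) / np.asarray(ml_times_ms)
    return {
        "mean": float(np.mean(eff)),           # Mean Efficiency
        "median": float(np.median(eff)),       # Median Efficiency
        "p10": float(np.percentile(eff, 10)),  # P10 (worst-decile) Efficiency
    }
```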

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 11: Test All Pipeline + Scheduler Combinations
Tests all 8 pipelines with both intrawave and interwave schedulers
for all convolution variants to determine which combinations work.
Usage:
python3 11_test_schedulers.py
python3 11_test_schedulers.py --arch gfx942
python3 11_test_schedulers.py --variant forward
"""
import sys
import argparse
import time
import numpy as np
from pathlib import Path
import json
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from grouped_conv_utils import (
    GroupedConvKernelConfig,
    GroupedConvProblem,
    GroupedConvRegistry,
    detect_gpu_arch,
)
# All pipelines from unified_grouped_conv_codegen.py
ALL_PIPELINES = [
    "basic_v1",
    "mem",
    "compv3",
    "compv4",
    "compv5",
    "compv6",
    "comp_async",
    "basic_async_v1",
]
# Both schedulers
ALL_SCHEDULERS = ["intrawave", "interwave"]
# Pipelines that require DoubleSmemBuffer=true (enforced by static_assert in
# the pipeline headers). Building these with dsb=false is a loud compile error.
PIPELINES_REQUIRING_DSB = {"compv4", "comp_async"}


def test_pipeline_scheduler(pipeline, scheduler, variant, arch, dtype, ndim=2):
    """
    Test if a pipeline+scheduler+variant combination builds and runs successfully.

    Args:
        pipeline: Pipeline name (e.g., "compv3", "mem")
        scheduler: Scheduler type ("intrawave" or "interwave")
        variant: Convolution variant (forward, bwd_data, bwd_weight)
        arch: GPU architecture (e.g., "gfx950")
        dtype: Data type (fp16, bf16)
        ndim: Spatial dimensions (2 or 3)

    Returns:
        dict with keys: pipeline, scheduler, variant, ndim, build_success, run_success, error_msg
    """
    result = {
        "pipeline": pipeline,
        "scheduler": scheduler,
        "variant": variant,
        "ndim": ndim,
        "arch": arch,
        "dtype": dtype,
        "build_success": False,
        "run_success": False,
        "error_msg": None,
        "time_ms": None,
        "tflops": None,
    }
    try:
        # Create registry with single kernel config
        reg = GroupedConvRegistry(f"{variant}_{pipeline}_{scheduler}_{ndim}d")
        # Use a simple, safe tile config: 16x64x64
        # wave 1x4x1, warp 16x16x16
        config = GroupedConvKernelConfig(
            variant=variant,
            ndim_spatial=ndim,
            arch=arch,
            dtype=dtype,
            tile_m=16,
            tile_n=64,
            tile_k=64,
            wave_m=1,
            wave_n=4,
            wave_k=1,
            warp_tile_m=16,
            warp_tile_n=16,
            warp_tile_k=16,
            pipeline=pipeline,
            scheduler=scheduler,  # Test scheduler here
            epilogue="cshuffle" if pipeline not in ["mem"] else "default",
            vector_size_a=4,
            vector_size_b=8,
            vector_size_c=8,
            block_per_cu=1,
            # compv4/comp_async require DoubleSmemBuffer=true (loud
            # static_assert otherwise); other pipelines do not.
            double_smem_buffer=(pipeline in PIPELINES_REQUIRING_DSB),
        )
        reg.add(config)
        # Try to build and run
        try:
            runners = reg.build(verbose=False, max_workers=1)
            key = (variant, ndim)
            if key in runners:
                result["build_success"] = True
                # Try to run
                np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32
                if ndim == 2:
                    prob = GroupedConvProblem(
                        N=1,
                        C=64,
                        K=64,
                        Hi=8,
                        Wi=8,
                        Y=3,
                        X=3,
                        pad_h=1,
                        pad_w=1,
                        direction=variant,
                    )
                else:  # 3D
                    prob = GroupedConvProblem(
                        N=1,
                        C=64,
                        K=64,
                        Di=4,
                        Hi=8,
                        Wi=8,
                        Z=3,
                        Y=3,
                        X=3,
                        pad_d=1,
                        pad_h=1,
                        pad_w=1,
                        direction=variant,
                    )
                # Generate inputs
                if variant == "forward":
                    x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(
                        np_dtype
                    )
                    w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(
                        np_dtype
                    )
                    res = runners[key].run(x, w, prob)
                elif variant == "bwd_data":
                    # Runner contract: input_np=dY, weight_np=W for bwd_data
                    w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(
                        np_dtype
                    )
                    dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(
                        np_dtype
                    )
                    res = runners[key].run(dy, w, prob)
                elif variant == "bwd_weight":
                    x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(
                        np_dtype
                    )
                    dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(
                        np_dtype
                    )
                    res = runners[key].run(x, dy, prob)
                if res.success and np.count_nonzero(res.output) > 0:
                    result["run_success"] = True
                    result["time_ms"] = res.time_ms
                    result["tflops"] = res.tflops
                else:
                    result["error_msg"] = "Kernel ran but produced zero output"
                # Cleanup
                runners[key].cleanup()
            else:
                result["error_msg"] = "Kernel not in runners (build failed)"
        except Exception as e:
            # This inner try covers both the build and the run.
            result["error_msg"] = f"Build/run exception: {str(e)}"
    except Exception as e:
        result["error_msg"] = f"Setup exception: {str(e)}"
    return result


def main():
    parser = argparse.ArgumentParser(
        description="Test All Pipeline + Scheduler Combinations"
    )
    parser.add_argument("--arch", default=detect_gpu_arch())
    parser.add_argument("--dtype", default="bf16", choices=["fp16", "bf16"])
    parser.add_argument(
        "--variant",
        default="all",
        choices=["all", "forward", "bwd_data", "bwd_weight"],
        help="Variant to test (default: all)",
    )
    parser.add_argument(
        "--ndim",
        type=int,
        default=2,
        choices=[2, 3],
        help="Spatial dimensions to test (default: 2)",
    )
    parser.add_argument(
        "--scheduler",
        default="all",
        choices=["all", "intrawave", "interwave"],
        help="Scheduler to test (default: all)",
    )
    parser.add_argument(
        "--output",
        default="scheduler_test_results.json",
        help="Output JSON file (default: scheduler_test_results.json)",
    )
    args = parser.parse_args()
    arch = args.arch

    print("=" * 80)
    print("Test All Pipeline + Scheduler Combinations")
    print("=" * 80)
    print(f"Arch: {arch}, Dtype: {args.dtype}, NDim: {args.ndim}D")
    print()

    # Determine variants to test
    if args.variant == "all":
        variants = ["forward", "bwd_data", "bwd_weight"]
    else:
        variants = [args.variant]

    # Determine schedulers to test
    if args.scheduler == "all":
        schedulers = ALL_SCHEDULERS
    else:
        schedulers = [args.scheduler]

    # Run tests
    all_results = []
    for variant in variants:
        print(f"\n{'=' * 80}")
        print(f"Testing {variant.upper()} ({args.ndim}D)")
        print(f"{'=' * 80}")
        print()
        print(
            f"{'Pipeline':<20} {'Scheduler':<12} {'Build':<8} {'Run':<8} {'Time (ms)':<12} {'TFLOPS':<10}"
        )
        print("-" * 80)
        for pipeline in ALL_PIPELINES:
            for scheduler in schedulers:
                result = test_pipeline_scheduler(
                    pipeline, scheduler, variant, arch, args.dtype, args.ndim
                )
                all_results.append(result)
                build_status = "✓" if result["build_success"] else "✗"
                run_status = "✓" if result["run_success"] else "✗"
                time_str = (
                    f"{result['time_ms']:.4f}"
                    if result["time_ms"] is not None
                    else "-"
                )
                tflops_str = (
                    f"{result['tflops']:.2f}" if result["tflops"] is not None else "-"
                )
                print(
                    f"{pipeline:<20} {scheduler:<12} {build_status:<8} {run_status:<8} {time_str:<12} {tflops_str:<10}"
                )
                if result["error_msg"] and not result["run_success"]:
                    print(f"{result['error_msg']}")
        print()

    # Summarize results by scheduler
    print("=" * 80)
    print("SUMMARY BY SCHEDULER")
    print("=" * 80)
    print()
    for scheduler in schedulers:
        print(f"\n{scheduler.upper()} Scheduler:")
        print("-" * 80)
        for variant in variants:
            variant_results = [
                r
                for r in all_results
                if r["variant"] == variant and r["scheduler"] == scheduler
            ]
            successful_build = [
                r["pipeline"] for r in variant_results if r["build_success"]
            ]
            successful_run = [r["pipeline"] for r in variant_results if r["run_success"]]
            print(f"\n{variant} ({args.ndim}D):")
            print(f" Build success ({len(successful_build)}/8): {successful_build}")
            print(f" Run success ({len(successful_run)}/8): {successful_run}")

    # Overall summary
    print("\n" + "=" * 80)
    print("OVERALL SUMMARY")
    print("=" * 80)
    print()
    # Per-pipeline support: a pipeline is "supported" if at least one
    # scheduler runs successfully. Not every pipeline supports both
    # intrawave and interwave (loud static_assert / unsupported trait
    # in some pipeline headers), so we only require one to work.
    per_variant_supported: dict[str, list[str]] = {}
    for variant in variants:
        print(f"{variant.upper()}:")
        # Group by pipeline; mark as supported if any scheduler succeeded
        supported_pipelines = []
        per_pipeline_status = []
        for pipeline in ALL_PIPELINES:
            schedulers_ok = [
                r["scheduler"]
                for r in all_results
                if r["variant"] == variant
                and r["pipeline"] == pipeline
                and r["run_success"]
            ]
            if schedulers_ok:
                supported_pipelines.append(pipeline)
                per_pipeline_status.append((pipeline, "✓", schedulers_ok))
            else:
                per_pipeline_status.append((pipeline, "✗", []))
        # Per-pipeline detail (any-scheduler-counts)
        for pipeline, status, sched_list in per_pipeline_status:
            sched_str = ",".join(sched_list) if sched_list else "none"
            print(f" {pipeline:<18}: {status} via [{sched_str}]")
        # Per-scheduler raw breakdown (for completeness)
        for scheduler in schedulers:
            variant_results = [
                r
                for r in all_results
                if r["variant"] == variant and r["scheduler"] == scheduler
            ]
            success_count = len([r for r in variant_results if r["run_success"]])
            total = len(variant_results)
            pct = (success_count / total * 100) if total > 0 else 0
            print(
                f" raw {scheduler:<10}: {success_count}/{total} ({pct:.0f}%) pipelines work"
            )
        # Any-scheduler aggregate
        n_sup = len(supported_pipelines)
        n_total = len(ALL_PIPELINES)
        agg_pct = (n_sup / n_total * 100) if n_total > 0 else 0
        agg_status = "✓" if n_sup > 0 else "✗"
        print(
            f" ANY scheduler : {agg_status} {n_sup}/{n_total} ({agg_pct:.0f}%) pipelines supported"
        )
        per_variant_supported[variant] = supported_pipelines
        print()

    # Save results
    output_file = Path(__file__).parent / args.output
    with open(output_file, "w") as f:
        json.dump(all_results, f, indent=2)
    print(f"Detailed results saved to: {output_file}")
    print()

    # Success criterion (relaxed): for each variant, at least one pipeline
    # must be supported by at least one scheduler. Pipelines that fail under
    # *both* schedulers are reported but don't fail the run, since some
    # pipelines genuinely don't support both schedulers.
    success = all(per_variant_supported.get(v) for v in variants)
    return 0 if success else 1


if __name__ == "__main__":
    sys.exit(main())