Files
composable_kernel/dispatcher/codegen/grouped_config_rules.py
Ville Pietilä 78d657c4f7 [rocm-libraries] ROCm/rocm-libraries#7284 (commit e7d25b2)
[CK_TILE] Integrate CK Tile Dispatcher code generation into
 CK Tile Profiler (#7284)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Motivation

CK Tile is going to be delivered to hipDNN via CK Dispatcher. Currently
the CK Tile Profiler uses CK Builder to generate the profiled
instances from the configuration files that identify the instances that
old CK exposes. We need to replace this instance generation with the CK
Tile Dispatcher codegen.

## Technical Details
The old CK Profiler config files are converted to JSON files that the CK
Tile Dispatcher can digest. The conversion script for configurations is
stored in source control in case we need to update the JSON
configurations later. The dispatcher generates instance libraries per
conv direction (fwd, bwd data, and bwd weight) that are linked to the CK
Profiler executable. I also implemented codegen for the stream-K and
depthwise conv instances. The proposed solution replaces the CK Builder
codegen with the CK Tile Dispatcher codegen.

There are two new methods that are exposed via the dispatcher backend:

- `is_supported` - required to enable the profiler workflow where we
check the applicability of the kernel instance before running it.
- `get_instance_string` - this is mainly for verification. It provides the
CK Builder instance string for verifying that the old CK Builder based
profiler and the new CK Tile Dispatcher based profiler have the same
instances.

The rules that limit the generated instances are now collected into a
single location under the dispatcher. The CK Builder codegen uses these,
which ensures that the two codegen pipelines are in sync. The next step
(different PR) is to remove the CK Builder codegen pipeline altogether.

## Test Plan

Verified that the old CK Builder based profiler and the new CK Tile
Dispatcher based profiler have the same instances, that is, the
Dispatcher based codegen can generate the same instances as the old CK
Builder.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-05-28 21:03:37 +00:00

410 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Single Source of Truth for Grouped Convolution Tile Configurations
This module defines all valid tile configurations for grouped convolution kernels.
Both codegen and instance_builder import from here to ensure consistency.
Architecture:
grouped_conv_tile_configs.py (SOURCE OF TRUTH)
├── Used by unified_grouped_conv_codegen.py
└── Used by grouped_conv_instance_builder.py
"""
from typing import Dict, Iterator, List, Optional, Tuple
# =============================================================================
# Tile Configurations (Single Source of Truth)
# =============================================================================
# Common tile configurations used across variants.
# Format: (tile_m, tile_n, tile_k)
# CRITICAL: tile_m MUST equal wave_m × warp_tile_m (TileGemmShape constraint).
# The wave / warp-tile layout of every entry is registered in TILE_TO_WAVE and
# TILE_TO_WARP; validate_tile_config() additionally requires a TILE_TO_VECTOR
# entry. Only tiles that successfully compile are included.
# print_summary() iterates this list in order.
COMMON_TILES: List[Tuple[int, int, int]] = [
    # Using warp_tile [16,16,16]: tile_m = wave_m × 16
    (16, 64, 64),  # 1 × 16 = 16, wave=(1,4,1)
    (32, 64, 64),  # 2 × 16 = 32, wave=(2,2,1)
    (64, 64, 64),  # 4 × 16 = 64, wave=(4,1,1)
    # (128, 64, 64),  # 8 × 16 = 128, wave=(8,2,1) - EXCLUDED: compile error
    # Using warp_tile [32,32,16]: tile_m = wave_m × 32
    (32, 128, 64),  # 1 × 32 = 32, wave=(1,4,1)
    (64, 128, 64),  # 2 × 32 = 64, wave=(2,2,1)
    (128, 128, 64),  # 4 × 32 = 128, wave=(4,4,1)
    # Using warp_tile [16,16,32]: tile_m = wave_m × 16
    (16, 64, 128),  # 1 × 16 = 16, wave=(1,4,1)
    (32, 64, 128),  # 2 × 16 = 32, wave=(2,2,1)
    (64, 64, 128),  # 4 × 16 = 64, wave=(4,1,1)
    (128, 64, 128),  # 8 × 16 = 128, wave=(8,2,1)
    # Excluded tiles:
    # - 128x64x64 (wave=8x2x1, warp=16x16x16): compile error
    # - 32x128x128, 64x128x128, 128x128x128, 256x128x128
    #   (warp_tile 32x32x32): compv4 issues
    # - 256x64x64, 256x128x128: arch filter rejection
    #   NOTE(review): 256x64x64 was also recorded as a "compilation issue",
    #   and 256x128x128 is listed under two different reasons; confirm which
    #   cause actually applies.
]
# Wave configurations per tile.
# Key: (tile_m, tile_n, tile_k) -> (wave_m, wave_n, wave_k)
# Constraint: tile_m == wave_m × warp_tile_m, where warp_tile_m comes from
# TILE_TO_WARP for the same key.
# Only use approved wave configs from arch_specs.json: [1,4,1], [2,2,1], [4,1,1], [8,2,1], [4,4,1]
# Keys must stay in sync with TILE_TO_WARP and TILE_TO_VECTOR: a tile that is
# missing from any of the three tables is rejected by validate_tile_config().
TILE_TO_WAVE: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    # warp_tile [16,16,16]
    (16, 64, 64): (1, 4, 1),
    (32, 64, 64): (2, 2, 1),
    (64, 64, 64): (4, 1, 1),
    # warp_tile [32,32,16]
    (32, 128, 64): (1, 4, 1),
    (64, 128, 64): (2, 2, 1),
    (128, 128, 64): (4, 4, 1),  # balanced 4x4 wave
    # warp_tile [16,16,32]
    (16, 64, 128): (1, 4, 1),
    (32, 64, 128): (2, 2, 1),
    (64, 64, 128): (4, 1, 1),
    (128, 64, 128): (8, 2, 1),
}
# Warp tile configurations (must match arch_specs.json gfx950 bf16 approved list).
# Key: (tile_m, tile_n, tile_k) -> (warp_m, warp_n, warp_k), i.e. the
# per-warp tile shape. These are tile sizes, not warp counts; the warp
# layout per tile lives in TILE_TO_WAVE.
# Keys must stay in sync with TILE_TO_WAVE and TILE_TO_VECTOR (see
# validate_tile_config()).
TILE_TO_WARP: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    # warp_tile [16,16,16]
    (16, 64, 64): (16, 16, 16),
    (32, 64, 64): (16, 16, 16),
    (64, 64, 64): (16, 16, 16),
    # warp_tile [32,32,16]
    (32, 128, 64): (32, 32, 16),
    (64, 128, 64): (32, 32, 16),
    (128, 128, 64): (32, 32, 16),
    # warp_tile [16,16,32]
    (16, 64, 128): (16, 16, 32),
    (32, 64, 128): (16, 16, 32),
    (64, 64, 128): (16, 16, 32),
    (128, 64, 128): (16, 16, 32),
}
# Vector sizes per tile (for memory operations).
# Key: (tile_m, tile_n, tile_k) -> (vec_a, vec_b, vec_c)
# Every registered tile currently uses the same widths, (4, 8, 8), so the
# table is expanded from that single value over the registered tiles, which
# are listed in the same order as COMMON_TILES. If a tile ever needs
# different widths, switch this back to an explicit literal.
TILE_TO_VECTOR: Dict[Tuple[int, int, int], Tuple[int, int, int]] = {
    registered_tile: (4, 8, 8)
    for registered_tile in (
        (16, 64, 64),
        (32, 64, 64),
        (64, 64, 64),
        (32, 128, 64),
        (64, 128, 64),
        (128, 128, 64),
        (16, 64, 128),
        (32, 64, 128),
        (64, 64, 128),
        (128, 64, 128),
    )
}
# =============================================================================
# Pipeline Variant Suffixes (single source of truth)
# =============================================================================
# Empirically verified valid (pipeline, wave_mode, has_dsb, has_si) combinations
# observed in the 2D and 3D bf16 gfx950 benchmark CSVs. 30 entries total per ndim.
# Each tuple: (pipeline, wave_mode, has_dsb, has_si)
#   wave_mode: "intrawave" | "interwave" (only basic_v1 and mem have interwave
#              rows, matching INTERWAVE_PIPELINES)
#   has_dsb: 1 if "_dsb" suffix present (double smem buffer), else 0
#   has_si:  1 if "_si" suffix present (store immediate), else 0
# NOTE(review): comp_async and basic_async_v1 appear in VARIANT_PIPELINES but
# have no rows here, so iter_pipeline_variants() yields nothing for them.
# Likewise, compv5 rows are kept even though is_valid_pipeline_for_variant()
# rejects compv5 for bwd_data / bwd_weight. Confirm both are intended.
PIPELINE_VARIANTS: List[Tuple[str, str, int, int]] = [
    # basic_v1: both intra/inter × {∅, dsb, si, dsb_si} = 8 combos
    ("basic_v1", "intrawave", 0, 0),
    ("basic_v1", "intrawave", 1, 0),
    ("basic_v1", "intrawave", 0, 1),
    ("basic_v1", "intrawave", 1, 1),
    ("basic_v1", "interwave", 0, 0),
    ("basic_v1", "interwave", 1, 0),
    ("basic_v1", "interwave", 0, 1),
    ("basic_v1", "interwave", 1, 1),
    # compv3: intrawave × {∅, dsb, si, dsb_si} = 4 combos
    ("compv3", "intrawave", 0, 0),
    ("compv3", "intrawave", 1, 0),
    ("compv3", "intrawave", 0, 1),
    ("compv3", "intrawave", 1, 1),
    # compv4: intrawave × {dsb, dsb_si} only = 2 combos (no non-dsb rows)
    ("compv4", "intrawave", 1, 0),
    ("compv4", "intrawave", 1, 1),
    # compv5: intrawave × {∅, dsb, si, dsb_si} = 4 combos
    ("compv5", "intrawave", 0, 0),
    ("compv5", "intrawave", 1, 0),
    ("compv5", "intrawave", 0, 1),
    ("compv5", "intrawave", 1, 1),
    # compv6: intrawave × {∅, dsb, si, dsb_si} = 4 combos
    ("compv6", "intrawave", 0, 0),
    ("compv6", "intrawave", 1, 0),
    ("compv6", "intrawave", 0, 1),
    ("compv6", "intrawave", 1, 1),
    # mem: both intra/inter × {∅, dsb, si, dsb_si} = 8 combos
    ("mem", "intrawave", 0, 0),
    ("mem", "intrawave", 1, 0),
    ("mem", "intrawave", 0, 1),
    ("mem", "intrawave", 1, 1),
    ("mem", "interwave", 0, 0),
    ("mem", "interwave", 1, 0),
    ("mem", "interwave", 0, 1),
    ("mem", "interwave", 1, 1),
]
def iter_pipeline_variants(
    pipelines: Optional[List[str]] = None,
) -> Iterator[Tuple[str, str, int, int]]:
    """Iterate (pipeline, wave_mode, has_dsb, has_si) tuples, optionally filtered.

    Entries are yielded in PIPELINE_VARIANTS order. Pipeline names that have
    no rows in PIPELINE_VARIANTS (e.g. comp_async, basic_async_v1) simply
    produce nothing.

    Args:
        pipelines: optional collection of pipeline names to keep. If None,
            yield every entry. An empty collection yields nothing. A single
            name passed as a plain string is treated as a one-name filter.

    Yields:
        PIPELINE_VARIANTS entries whose pipeline name is in ``pipelines``.
    """
    if pipelines is None:
        yield from PIPELINE_VARIANTS
        return
    if isinstance(pipelines, str):
        # set("mem") would be {"m", "e"} and silently match nothing; a bare
        # string is clearly meant as one pipeline name.
        pipelines = [pipelines]
    keep = set(pipelines)
    for entry in PIPELINE_VARIANTS:
        if entry[0] in keep:
            yield entry
# Valid pipelines per conv variant.
# All 8 pipelines (basic_v1, mem, compv3-6, comp_async, basic_async_v1) were
# reported to build and run for all variants in both 2D and 3D (verified via
# 10_test_all_pipelines.py), so every variant shares the same ordered list.
# The comprehension builds a separate list object per variant, so one
# variant's list can be edited without affecting the others.
# NOTE(review): "compv5" is listed for bwd_data / bwd_weight, but
# is_valid_pipeline_for_variant() rejects it for both; presumably callers apply
# that filter on top of this table. Confirm.
VARIANT_PIPELINES: Dict[str, List[str]] = {
    conv_variant: [
        "basic_v1",
        "mem",
        "compv3",
        "compv4",
        "compv5",
        "compv6",
        "comp_async",
        "basic_async_v1",
    ]
    for conv_variant in ("forward", "bwd_data", "bwd_weight")
}
# Tiles that support the compv4 pipeline.
# compv4 has stricter requirements due to double buffering and LDS constraints.
# Pattern: only warp_tile [16,16,16] or [16,16,32] work with compv4.
# Large warp_tile [32,32,16] and wave [8,2,1] fail arch validation for compv4,
# so every COMMON_TILES entry using either of those is left out below.
COMPV4_COMPATIBLE_TILES: List[Tuple[int, int, int]] = [
    # warp_tile [16,16,16] - all work with compv4
    (16, 64, 64),
    (32, 64, 64),
    (64, 64, 64),
    # (128, 64, 64),  # Excluded: wave=8x2x1 fails for compv4
    # warp_tile [16,16,32] - all work with compv4
    (16, 64, 128),
    (32, 64, 128),
    (64, 64, 128),
    # (128, 64, 128),  # Excluded: wave=8x2x1 fails for compv4
]
# Backward weight tiles (expected to be restricted by transpose_tile2d
# constraints).
# NOTE(review): as written, this list contains exactly the same ten tiles, in
# the same order, as COMMON_TILES, with all but one entry still marked "test".
# It therefore does not currently restrict anything. Prune it to the confirmed
# working subset once verified.
BWD_WEIGHT_TILES: List[Tuple[int, int, int]] = [
    # warp_tile [16,16,16]
    (16, 64, 64),  # Known working config
    (32, 64, 64),  # Under test
    (64, 64, 64),  # Under test
    # warp_tile [32,32,16]
    (32, 128, 64),  # Under test
    (64, 128, 64),  # Under test
    (128, 128, 64),  # Under test
    # warp_tile [16,16,32]
    (16, 64, 128),  # Under test
    (32, 64, 128),  # Under test
    (64, 64, 128),  # Under test
    (128, 64, 128),  # Under test
]
# =============================================================================
# Shared Validation Rules
# =============================================================================
# These functions are the single source of truth for the validation rules
# used by convolution code generation.
# Threads per warp (wavefront) assumed by the tile coverage checks in this
# module. NOTE(review): hard-coded to 64 lanes; revisit before using these
# rules for 32-lane targets.
WARP_SIZE = 64


# --- Vector size validation ---
def is_valid_vector_size(vec: int) -> bool:
    """Return True if ``vec`` is an acceptable vector width.

    Accepts 1 or any positive even width. Zero and negative widths are
    rejected because they cannot describe a vector access.

    Note: this is an odd-width filter, not a full hardware check. Even widths
    that are not powers of two (e.g. 6) also pass, so it is looser than the
    1/2/4/8/16 set of widths.
    """
    if vec <= 0:
        return False
    return vec == 1 or vec % 2 == 0


def check_vectors(vec_a: int, vec_b: int, vec_c: int) -> bool:
    """Return True if vec_a, vec_b and vec_c all pass is_valid_vector_size()."""
    return all(is_valid_vector_size(v) for v in (vec_a, vec_b, vec_c))
# --- Tile coverage validation ---
def check_warp_coverage(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    vec_a: int,
    vec_b: int,
    variant: str = "forward",
) -> bool:
    """Return True if a single warp's vector loads can span the A and B tile dims.

    One warp of WARP_SIZE threads loading ``vec`` elements each covers at
    most ``WARP_SIZE * vec`` elements along a dimension. A tile dimension
    larger than that is rejected.

    The A-tile dimension depends on the conv direction:
      * "bwd_data": tile_k
      * any other variant (forward / bwd_weight): tile_m
    The B-tile dimension is always tile_n.
    """
    a_dim = tile_k if variant == "bwd_data" else tile_m
    a_fits = a_dim <= WARP_SIZE * vec_a
    b_fits = tile_n <= WARP_SIZE * vec_b
    return a_fits and b_fits
def check_bwd_data_vec_coverage(
    tile_m: int,
    tile_n: int,
    tile_k: int,
    warp_m: int,
    warp_n: int,
    warp_k: int,
    vec_a: int,
    vec_b: int,
) -> bool:
    """Bwd_data: reject vector widths larger than a thread's share of a tile slice.

    The block is treated as ``WARP_SIZE * warp_m * warp_n * warp_k`` threads.
    For that product to be a thread count, warp_m/n/k must be warp *counts*
    per dimension, not warp-tile sizes.
    NOTE(review): confirm callers pass wave counts (TILE_TO_WAVE values), not
    TILE_TO_WARP values.

    The (tile_m x tile_k) and (tile_n x tile_k) slices are divided among the
    threads with floor division. vec_a and vec_b respectively must not exceed
    the resulting per-thread element counts.
    """
    threads = WARP_SIZE * warp_m * warp_n * warp_k
    per_thread_a = (tile_m * tile_k) // threads
    per_thread_b = (tile_n * tile_k) // threads
    return vec_a <= per_thread_a and vec_b <= per_thread_b
# --- Pipeline-scheduler restrictions ---
# The only pipelines that accept the "interwave" scheduler.
INTERWAVE_PIPELINES = {"basic_v1", "mem"}


def is_valid_pipeline_scheduler(pipeline: str, scheduler: str) -> bool:
    """Return whether ``pipeline`` may be paired with ``scheduler``.

    The "interwave" scheduler is allowed only for pipelines in
    INTERWAVE_PIPELINES ('mem' and 'basic_v1'). All compute pipelines
    (compv3/v4/v5/v6/async) are intrawave-only. Any scheduler other than
    "interwave" is accepted for every pipeline.
    """
    return scheduler != "interwave" or pipeline in INTERWAVE_PIPELINES
# --- Pipeline-variant restrictions ---
# Pipelines that must not be generated for a given conv variant. Variants
# without an entry (e.g. "forward") have no pipeline restrictions here.
UNSUPPORTED_VARIANT_PIPELINES = {
    "bwd_weight": {"compv5"},
    "bwd_data": {"compv5"},
}


def is_valid_pipeline_for_variant(pipeline: str, variant: str) -> bool:
    """Return False if ``pipeline`` is blocked for ``variant``, else True.

    Backward weight and backward data reject compv5 due to transpose_tile2d /
    get_length issues. Unknown variants accept every pipeline.
    """
    return pipeline not in UNSUPPORTED_VARIANT_PIPELINES.get(variant, set())
# --- Stream-K restrictions ---
def is_streamk_valid_for_variant(variant: str) -> bool:
    """Return True only for the "bwd_weight" variant.

    Stream-K generation is limited to backward weight; every other variant
    is rejected.
    """
    streamk_variant = "bwd_weight"
    return variant == streamk_variant
# =============================================================================
# Tile Registration Validation
# =============================================================================
def validate_tile_config(tile_m: int, tile_n: int, tile_k: int) -> bool:
    """Return True if (tile_m, tile_n, tile_k) is registered in every tile table.

    A tile counts as registered only when TILE_TO_WAVE, TILE_TO_WARP and
    TILE_TO_VECTOR all contain it. Membership in COMMON_TILES is not checked.
    """
    key = (tile_m, tile_n, tile_k)
    tables = (TILE_TO_WAVE, TILE_TO_WARP, TILE_TO_VECTOR)
    return all(key in table for table in tables)
def get_tile_full_config(
    tile_m: int, tile_n: int, tile_k: int
) -> Optional[Dict[str, int]]:
    """Get the complete configuration for a tile size.

    Args:
        tile_m, tile_n, tile_k: the tile shape to look up.

    Returns:
        A new dict with keys tile_m, tile_n, tile_k, wave_m, wave_n, wave_k,
        warp_m, warp_n, warp_k, vec_a, vec_b, vec_c. Returns None when the
        tile is not registered (validate_tile_config() returns False).
        wave_* come from TILE_TO_WAVE, warp_* are the warp-tile sizes from
        TILE_TO_WARP, and vec_* come from TILE_TO_VECTOR.
    """
    if not validate_tile_config(tile_m, tile_n, tile_k):
        return None
    tile_key = (tile_m, tile_n, tile_k)
    wave_m, wave_n, wave_k = TILE_TO_WAVE[tile_key]
    warp_m, warp_n, warp_k = TILE_TO_WARP[tile_key]
    vec_a, vec_b, vec_c = TILE_TO_VECTOR[tile_key]
    return {
        "tile_m": tile_m,
        "tile_n": tile_n,
        "tile_k": tile_k,
        "wave_m": wave_m,
        "wave_n": wave_n,
        "wave_k": wave_k,
        "warp_m": warp_m,
        "warp_n": warp_n,
        "warp_k": warp_k,
        "vec_a": vec_a,
        "vec_b": vec_b,
        "vec_c": vec_c,
    }
# =============================================================================
# Summary Statistics
# =============================================================================
def print_summary():
    """Write the COMMON_TILES table to stdout, one row per tile with its wave/warp layout."""
    rule = "=" * 80
    header_lines = (
        rule,
        "Grouped Convolution Tile Configurations (Single Source of Truth)",
        rule,
        f"Total tiles: {len(COMMON_TILES)}",
        f"Backward weight tiles: {len(BWD_WEIGHT_TILES)}",
        "",
        "Tile sizes (M×N×K):",
    )
    for text in header_lines:
        print(text)

    row_template = " {:3}×{:3}×{:3} wave={}×{}×{} warp={}×{}×{}"
    for tile in COMMON_TILES:
        m, n, k = tile
        wave = TILE_TO_WAVE[tile]
        warp = TILE_TO_WARP[tile]
        print(
            row_template.format(
                m, n, k,
                wave[0], wave[1], wave[2],
                warp[0], warp[1], warp[2],
            )
        )
    print(rule)
if __name__ == "__main__":
    # Running this module as a script only prints the tile summary table.
    print_summary()