Yaswanth Raparti 470ff04817 [rocm-libraries] ROCm/rocm-libraries#6445 (commit 2225e10)
[CK][CK_TILE] Fix library caching bug in gemm dispatcher (#6445)

## Motivation

setup_gemm_dispatcher() was rebuilding libraries on every call instead
of reusing cached libraries.

**Root Cause**:
1. Library names encoded only dtype and layout, so configs that differed
in tile/wave/warp overwrote each other's libraries
2. No cache check: the code always loaded the default library, detected a
config mismatch, and then rebuilt

## Technical Details

**Solution**:
1. Library names now encode every distinguishing parameter:
`libdispatcher_gemm_{dtype}_{layout}_{tile}_{wave}_{warp}_{pipeline}_{epilogue}_{scheduler}.so`

2. Cache check before rebuilding:
   - Check whether a library for the exact config already exists
   - Reuse it if found (500x faster: 0.02s vs 10s)
   - Rebuild only when no cached library exists

3. Better error handling for kernel generation failures
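A minimal sketch of the naming and cache-check logic described above. `GemmLibConfig`, `library_name` and `get_or_build_library` are hypothetical names, not the code in ctypes_utils.py. Rendering each tile/wave/warp tuple as `256x256x64` and the example field values are assumptions, and `build` stands in for the real kernel generation and compile step.

```python
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Tuple

Dims = Tuple[int, int, int]


@dataclass(frozen=True)
class GemmLibConfig:
    """Hypothetical stand-in for the fields that distinguish one GEMM library."""
    dtype: str
    layout: str
    tile: Dims
    wave: Dims
    warp: Dims
    pipeline: str
    epilogue: str
    scheduler: str


def _dims(t: Dims) -> str:
    # Assumed rendering: (256, 256, 64) -> "256x256x64"
    return "x".join(str(v) for v in t)


def library_name(cfg: GemmLibConfig) -> str:
    """Encode every distinguishing parameter so different configs never share a file."""
    return (
        f"libdispatcher_gemm_{cfg.dtype}_{cfg.layout}_{_dims(cfg.tile)}_"
        f"{_dims(cfg.wave)}_{_dims(cfg.warp)}_{cfg.pipeline}_"
        f"{cfg.epilogue}_{cfg.scheduler}.so"
    )


def get_or_build_library(
    cfg: GemmLibConfig,
    build_dir: Path,
    build: Callable[[GemmLibConfig, Path], object],
) -> Tuple[Path, bool]:
    """Return (library path, cache_hit); `build` runs only on a cache miss."""
    lib_path = build_dir / library_name(cfg)
    if lib_path.exists():
        return lib_path, True  # cache hit: reuse the existing .so
    build(cfg, lib_path)  # cache miss: generate and compile once
    if not lib_path.exists():
        # Surface generation failures instead of silently loading a default library
        raise RuntimeError(f"kernel generation failed for {lib_path.name}")
    return lib_path, False


if __name__ == "__main__":
    cfg = GemmLibConfig(
        "fp16", "rcr", (256, 256, 64), (2, 2, 1), (32, 32, 16),
        "compv4", "cshuffle", "intrawave",
    )
    print(library_name(cfg))

    def fake_build(c: GemmLibConfig, path: Path) -> None:
        path.write_bytes(b"")  # placeholder for the real compile step

    with tempfile.TemporaryDirectory() as d:
        print(get_or_build_library(cfg, Path(d), fake_build)[1])  # False: built
        print(get_or_build_library(cfg, Path(d), fake_build)[1])  # True: cached
```

Because the file name encodes every parameter, two configs share a library only when they describe the same kernel, so an `exists()` check is enough to choose between reuse and rebuild.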

Files Changed:
- dispatcher/python/ctypes_utils.py
- dispatcher/tests/test_library_caching.py (new unit test)

## Test Plan

Use `dispatcher/tests/test_library_caching.py` to verify that libraries
are cached and rebuilt only when they are not present in the build directory.

1. **test_01_unique_library_naming** - Library names include all
parameters (dtype, layout, tile, wave, warp, pipeline, epilogue,
scheduler)
2. **test_02_library_build_and_cache** - Libraries are built once and
then cached for reuse
3. **test_03_different_configs_different_libraries** - Different configs
create different library files
4. **test_04_cache_message_verification** - Cache hit messages are
logged correctly
5. **test_05_code_fix_verification** - Code changes are present in
ctypes_utils.py
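The build-once and cache-message checks (test_02 and test_04) could be exercised along these lines. This is a hypothetical, self-contained sketch using the standard `unittest` module, not the contents of test_library_caching.py; `load_library`, the logger name and the log messages are assumptions, and writing an empty file stands in for a real build.

```python
import logging
import tempfile
import unittest
from pathlib import Path

log = logging.getLogger("dispatcher.cache")  # hypothetical logger name


def load_library(name: str, build_dir: Path, builds: list) -> Path:
    """Hypothetical cached loader: build once, then log and reuse."""
    path = build_dir / name
    if path.exists():
        log.info("Using cached library: %s", path.name)
        return path
    builds.append(name)  # record that a (fake) build happened
    path.write_bytes(b"")  # placeholder for the real compile step
    log.info("Built library: %s", path.name)
    return path


class LibraryCachingSketch(unittest.TestCase):
    def setUp(self):
        self._tmp = tempfile.TemporaryDirectory()
        self.build_dir = Path(self._tmp.name)
        self.builds = []

    def tearDown(self):
        self._tmp.cleanup()

    def test_build_once_then_cache(self):
        name = "libdispatcher_gemm_example.so"
        first = load_library(name, self.build_dir, self.builds)
        second = load_library(name, self.build_dir, self.builds)
        self.assertEqual(first, second)
        self.assertEqual(self.builds, [name])  # built exactly once

    def test_cache_hit_is_logged(self):
        name = "libdispatcher_gemm_example.so"
        load_library(name, self.build_dir, self.builds)
        with self.assertLogs(log, level="INFO") as captured:
            load_library(name, self.build_dir, self.builds)
        self.assertTrue(any("Using cached library" in m for m in captured.output))
```

Run it with `python -m unittest <module>`; each test gets a fresh temporary build directory, so the tests do not depend on each other.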

## Test Result
All the tests above passed.

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
2026-04-16 01:07:37 +00:00

CK Tile Dispatcher Python Utilities

This directory contains Python utilities used by the dispatcher examples.

Contents

Shared Utilities (used by both GEMM and Grouped Conv)

  • dispatcher_common.py - Shared dispatcher infrastructure
    • Path helpers (get_dispatcher_root, get_build_dir, etc.)
    • ValidationResultBase - Structured validation feedback
    • validate_wave_config, validate_warp_tile_config, validate_trait_combo
    • auto_correct_wave, auto_correct_trait - Auto-correction helpers
    • Colors - Cross-platform ANSI color support
    • print_phase, print_success, print_error, print_info - Phased output
    • cleanup_generated_kernels - Cleanup helper
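A sketch of how a structured validation result and an auto-correction helper can fit together. `ValidationResultSketch`, `validate_wave_sketch`, `auto_correct_wave_sketch` and the allowed wave set are hypothetical placeholders; the real ValidationResultBase, validate_wave_config and auto_correct_wave may use different fields, signatures and rules.

```python
from dataclasses import dataclass, field
from typing import List, Sequence, Tuple

Wave = Tuple[int, int, int]

# Placeholder set for illustration only; these values are NOT taken from
# the dispatcher, which may derive its rules per architecture.
EXAMPLE_ALLOWED_WAVES: Sequence[Wave] = ((1, 4, 1), (2, 2, 1), (4, 1, 1))


@dataclass
class ValidationResultSketch:
    """Hypothetical structured result: errors block a config, warnings do not."""
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    @property
    def is_valid(self) -> bool:
        return not self.errors


def validate_wave_sketch(
    wave: Wave, allowed: Sequence[Wave] = EXAMPLE_ALLOWED_WAVES
) -> ValidationResultSketch:
    result = ValidationResultSketch()
    if tuple(wave) not in allowed:
        result.errors.append(f"wave {wave} not in allowed set {list(allowed)}")
    return result


def auto_correct_wave_sketch(
    wave: Wave, allowed: Sequence[Wave] = EXAMPLE_ALLOWED_WAVES
) -> Tuple[Wave, ValidationResultSketch]:
    """Return a usable wave config plus the validation record for the ORIGINAL input."""
    result = validate_wave_sketch(wave, allowed)
    if result.is_valid:
        return tuple(wave), result
    corrected = allowed[0]  # simplest policy: fall back to the first allowed entry
    result.warnings.append(f"auto-corrected wave {wave} -> {corrected}")
    return corrected, result
```

Keeping the errors of the original input alongside the correction warning lets a caller report both what was wrong and what was substituted.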

GEMM Utilities

  • ctypes_utils.py - Core ctypes utilities for GEMM Python examples
    • KernelConfig - Kernel configuration dataclass
    • setup_gemm_dispatcher() - Set up the dispatcher with auto-correction
    • cleanup_gemm() - Clean up dispatcher resources
    • GemmRunner - GPU execution helper
    • Auto-correction and validation utilities

Grouped Convolution Utilities

  • grouped_conv_utils.py - Utilities for grouped convolution
    • GroupedConvValidationResult - Validation result (extends ValidationResultBase)
    • validate_grouped_conv_config - Validate a grouped conv config
    • auto_correct_grouped_conv_config - Auto-correct invalid configs
    • get_grouped_conv_default_config - Get default config for a variant
    • GroupedConvDataType - Data type enum (FP16, BF16, FP32, FP8, BF8, INT8)
    • format_grouped_conv_summary - Human-readable config summary
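A hypothetical sketch of a data-type enum covering the six types listed for GroupedConvDataType, extended with element sizes so buffer sizes can be computed. `DataTypeSketch`, `from_str` and `tensor_bytes` are illustrative names, not the module's API; the byte widths (2 for FP16/BF16, 4 for FP32, 1 for the 8-bit types) follow from the formats themselves.

```python
import math
from enum import Enum


class DataTypeSketch(Enum):
    """Hypothetical mirror of the listed data types, tagged with bytes per element."""

    FP16 = ("fp16", 2)
    BF16 = ("bf16", 2)
    FP32 = ("fp32", 4)
    FP8 = ("fp8", 1)
    BF8 = ("bf8", 1)
    INT8 = ("int8", 1)

    def __init__(self, tag: str, nbytes: int) -> None:
        self.tag = tag
        self.nbytes = nbytes

    @classmethod
    def from_str(cls, text: str) -> "DataTypeSketch":
        """Case-insensitive lookup, e.g. "FP16" or "fp16"."""
        for member in cls:
            if member.tag == text.lower():
                return member
        raise ValueError(f"unsupported data type: {text!r}")


def tensor_bytes(dtype: DataTypeSketch, *dims: int) -> int:
    """Bytes needed for a dense tensor with the given dimensions."""
    return math.prod(dims) * dtype.nbytes


if __name__ == "__main__":
    # e.g. an NHWC fp16 activation with N=2, H=W=56, C=64
    print(tensor_bytes(DataTypeSketch.from_str("FP16"), 2, 56, 56, 64))  # 802816
```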

Usage

GEMM Examples

The GEMM Python examples in dispatcher/examples/gemm/python/ import:

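import sys
from pathlib import Path

# Four levels up from dispatcher/examples/gemm/python/<script>.py is dispatcher/,
# so this puts dispatcher/python on the import path.
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))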

from ctypes_utils import (
    KernelConfig,
    setup_gemm_dispatcher,
    cleanup_gemm,
    GemmRunner,
)

Grouped Conv Usage

import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))

from grouped_conv_utils import (
    validate_grouped_conv_config,
    auto_correct_grouped_conv_config,
    get_grouped_conv_default_config,
    GroupedConvDataType,
)

# Get a default config
config = get_grouped_conv_default_config(variant="forward", arch="gfx942")

# Validate
result = validate_grouped_conv_config(config)
print(f"Valid: {result.is_valid}")

Requirements

  • Python 3.8+
  • NumPy
  • HIP runtime (for GPU execution)
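A quick, hypothetical environment check for the requirements listed above, using only the standard library. `check_requirements` is not part of this directory. It assumes the HIP runtime is exposed as the `amdhip64` shared library, which `ctypes.util.find_library` can miss when ROCm is installed outside the standard loader paths.

```python
import ctypes.util
import importlib.util
import sys


def check_requirements() -> dict:
    """Report which of the listed requirements appear to be satisfied."""
    return {
        "python>=3.8": sys.version_info >= (3, 8),
        # find_spec checks importability without actually importing NumPy
        "numpy": importlib.util.find_spec("numpy") is not None,
        # libamdhip64 is the HIP runtime library on ROCm installs; find_library
        # only searches standard loader paths, so custom prefixes may be missed
        "hip_runtime": ctypes.util.find_library("amdhip64") is not None,
    }


if __name__ == "__main__":
    for req, ok in check_requirements().items():
        print(f"{req:12} {'OK' if ok else 'MISSING'}")
```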