mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
[CK][CK_TILE] Fix library caching bug in gemm dispatcher (#6445)

## Motivation

setup_gemm_dispatcher() was rebuilding libraries on every call instead of reusing cached libraries.

**Root Cause**:
1. Library names only included dtype+layout, causing different tile/wave/warp configs to overwrite each other
2. No cache checking - always loaded default library, detected mismatch, then rebuilt

## Technical Details

**Solution**:
1. Complete library naming with all distinguishing parameters: libdispatcher_gemm_{dtype}_{layout}_{tile}_{wave}_{warp}_{pipeline}_{epilogue}_{scheduler}.so
2. Cache checking before rebuild:
   - Check if library for exact config already exists
   - Reuse if found (500x faster: 0.02s vs 10s)
   - Only rebuild when no cached library exists
3. Better error handling for kernel generation failures

Files Changed:
- dispatcher/python/ctypes_utils.py
- dispatcher/tests/test_library_caching.py (new unit test)

## Test Plan

Use `dispatcher/tests/test_library_caching.py` to ensure that libraries are cached and only rebuilt if they are not present in the build directory.

1. **test_01_unique_library_naming** - Library names include all parameters (dtype, layout, tile, wave, warp, pipeline, epilogue, scheduler)
2. **test_02_library_build_and_cache** - Libraries are built once and then cached for reuse
3. **test_03_different_configs_different_libraries** - Different configs create different library files
4. **test_04_cache_message_verification** - Cache hit messages are logged correctly
5. **test_05_code_fix_verification** - Code changes are present in ctypes_utils.py

## Test Result

All of the tests above passed.

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
CK Tile Dispatcher Python Utilities
This directory contains Python utilities used by the dispatcher examples.
Contents
Shared Utilities (used by both GEMM and Grouped Conv)
- dispatcher_common.py - Shared dispatcher infrastructure
  - Path helpers (get_dispatcher_root, get_build_dir, etc.)
  - ValidationResultBase - Structured validation feedback
  - validate_wave_config, validate_warp_tile_config, validate_trait_combo
  - auto_correct_wave, auto_correct_trait - Auto-correction helpers
  - Colors - Cross-platform ANSI color support
  - print_phase, print_success, print_error, print_info - Phased output
  - cleanup_generated_kernels - Cleanup helper
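To make "structured validation feedback" concrete, here is a minimal sketch of what such a result object could look like. It is not the code in dispatcher_common.py. Only the is_valid attribute is attested elsewhere in this README; the class name SketchValidationResult, the errors field, the add_error method, and the "three positive integers" rule are all assumptions made for illustration.

```python
from dataclasses import dataclass, field
from typing import List, Sequence


@dataclass
class SketchValidationResult:
    """Hypothetical result type; only `is_valid` is attested by this README."""

    is_valid: bool = True
    errors: List[str] = field(default_factory=list)

    def add_error(self, message: str) -> None:
        # Recording any error flips the result to invalid.
        self.errors.append(message)
        self.is_valid = False


def sketch_validate_wave(wave: Sequence[int]) -> SketchValidationResult:
    """Placeholder rule (three positive integers), NOT the dispatcher's real rule."""
    result = SketchValidationResult()
    if len(wave) != 3:
        result.add_error(f"expected 3 wave dimensions, got {len(wave)}")
    for i, dim in enumerate(wave):
        if not isinstance(dim, int) or dim <= 0:
            result.add_error(f"wave[{i}] must be a positive integer, got {dim!r}")
    return result
```

Collecting every problem in one object, rather than raising on the first one, lets a caller print all issues at once or hand them to an auto-correction step.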
GEMM Utilities
- ctypes_utils.py - Core ctypes utilities for GEMM Python examples
  - KernelConfig - Kernel configuration dataclass
  - setup_gemm_dispatcher() - Setup dispatcher with auto-correction
  - cleanup_gemm() - Cleanup dispatcher resources
  - GemmRunner - GPU execution helper
  - Auto-correction and validation utilities
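The commit note at the top of this page says setup_gemm_dispatcher() now names each library after its full configuration and checks for an existing file before rebuilding. The sketch below illustrates that idea only; it is not the code in ctypes_utils.py. LibKey, find_or_build, and the build_fn callback are hypothetical names, and how each field is encoded as a string is an assumption.

```python
from dataclasses import dataclass
from pathlib import Path


@dataclass(frozen=True)
class LibKey:
    """Hypothetical stand-in for the fields that distinguish one build from another."""

    dtype: str
    layout: str
    tile: str
    wave: str
    warp: str
    pipeline: str
    epilogue: str
    scheduler: str

    def library_name(self) -> str:
        # Follows the naming pattern quoted in the commit note.
        return (
            f"libdispatcher_gemm_{self.dtype}_{self.layout}_{self.tile}_"
            f"{self.wave}_{self.warp}_{self.pipeline}_{self.epilogue}_"
            f"{self.scheduler}.so"
        )


def find_or_build(key, build_dir, build_fn):
    """Return (path, cache_hit); build_fn(path) is called only on a cache miss."""
    path = Path(build_dir) / key.library_name()
    if path.exists():
        return path, True  # exact config already built: reuse it
    build_fn(path)  # assumed to create the file at `path`
    return path, False
```

Because every distinguishing parameter is part of the file name, two configurations that differ only in tile, wave, or warp shape map to different files and cannot overwrite each other.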
Grouped Convolution Utilities
- grouped_conv_utils.py - Utilities for grouped convolution
  - GroupedConvValidationResult - Validation result (extends ValidationResultBase)
  - validate_grouped_conv_config - Validate a grouped conv config
  - auto_correct_grouped_conv_config - Auto-correct invalid configs
  - get_grouped_conv_default_config - Get default config for a variant
  - GroupedConvDataType - Data type enum (FP16, BF16, FP32, FP8, BF8, INT8)
  - format_grouped_conv_summary - Human-readable config summary
Usage
GEMM Examples
The GEMM Python examples in dispatcher/examples/gemm/python/ import:
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
GemmRunner,
)
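The sys.path line above walks four directories up from the example file before appending "python". The short sketch below traces that walk with a pure path, so it runs without the repository checked out. The file name my_example.py is hypothetical; only the directory layout comes from this README.

```python
from pathlib import PurePosixPath

# Hypothetical example file inside dispatcher/examples/gemm/python/.
example = PurePosixPath("dispatcher/examples/gemm/python/my_example.py")

# .parent x1 -> dispatcher/examples/gemm/python
# .parent x2 -> dispatcher/examples/gemm
# .parent x3 -> dispatcher/examples
# .parent x4 -> dispatcher
utils_dir = example.parent.parent.parent.parent / "python"
print(utils_dir)  # dispatcher/python

# Equivalent, more compact spelling: parents[3] is the fourth ancestor.
assert example.parents[3] / "python" == utils_dir
```

If an example script is moved to a different depth in the tree, the number of .parent hops has to change with it.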
Grouped Conv Usage
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from grouped_conv_utils import (
validate_grouped_conv_config,
auto_correct_grouped_conv_config,
get_grouped_conv_default_config,
GroupedConvDataType,
)
# Get a default config
config = get_grouped_conv_default_config(variant="forward", arch="gfx942")
# Validate
result = validate_grouped_conv_config(config)
print(f"Valid: {result.is_valid}")
Requirements
- Python 3.8+
- NumPy
- HIP runtime (for GPU execution)
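A quick pre-flight check for these requirements could look like the sketch below. The function check_requirements is a hypothetical helper, not part of the dispatcher utilities, and looking for hipcc on PATH is only a heuristic assumed here to hint at a HIP installation; it does not prove that a usable GPU runtime is present.

```python
import importlib.util
import shutil
import sys


def check_requirements(min_python=(3, 8)):
    """Return a dict mapping each requirement name to True/False."""
    return {
        # Tuple comparison against sys.version_info is the standard idiom.
        "python": sys.version_info >= min_python,
        # find_spec looks NumPy up without actually importing it.
        "numpy": importlib.util.find_spec("numpy") is not None,
        # Heuristic only (assumption): hipcc on PATH suggests a HIP install.
        "hip": shutil.which("hipcc") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_requirements().items():
        print(f"{name}: {'found' if ok else 'missing'}")
```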