mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[CK][CK_TILE] Fix library caching bug in gemm dispatcher (#6445) ## Motivation setup_gemm_dispatcher() was rebuilding libraries on every call instead of reusing cached libraries. **Root Cause**: 1. Library names only included dtype+layout, causing different tile/wave/warp configs to overwrite each other 2. No cache checking - always loaded default library, detected mismatch, then rebuilt ## Technical Details **Solution**: 1. Complete library naming with all distinguishing parameters: libdispatcher_gemm_{dtype}_{layout}_{tile}_{wave}_{warp}_{pipeline}_{epilogue}_{scheduler}.so 2. Cache checking before rebuild: - Check if library for exact config already exists - Reuse if found (500x faster: 0.02s vs 10s) - Only rebuild when no cached library exists 3. Better error handling for kernel generation failures Files Changed: - dispatcher/python/ctypes_utils.py - dispatcher/tests/test_library_caching.py (new unit test) ## Test Plan Use `dispatcher/tests/test_library_caching.py` to ensure that libraries are cached and only rebuilt if they are not present in the build directory 1. **test_01_unique_library_naming** - Library names include all parameters (dtype, layout, tile, wave, warp, pipeline, epilogue, scheduler) 2. **test_02_library_build_and_cache** - Libraries are built once and then cached for reuse 3. **test_03_different_configs_different_libraries** - Different configs create different library files 4. **test_04_cache_message_verification** - Cache hit messages are logged correctly 5. **test_05_code_fix_verification** - Code changes are present in ctypes_utils.py ## Test Result All the tests above passed. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
295 lines
10 KiB
Python
Executable File
295 lines
10 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""
|
|
Unit tests for library caching in setup_gemm_dispatcher().
|
|
|
|
Tests verify that:
|
|
1. Different kernel configs create unique library files with complete naming
|
|
2. Repeated configs reuse cached libraries (no redundant rebuilds)
|
|
3. Library names include all distinguishing parameters (dtype, layout, tile, wave, warp, pipeline, epilogue, scheduler)
|
|
4. Kernel headers are generated when missing
|
|
"""
|
|
|
|
import sys
|
|
import time
|
|
import unittest
|
|
from pathlib import Path
|
|
|
|
# Add the dispatcher's python/ directory to the import path so that
# `ctypes_utils` can be imported below.
#
# The path is resolved first: when this file is run as a script from its own
# directory on interpreters where `__file__` is relative, the unresolved
# `Path(__file__).parent.parent` collapses to "." and points at the wrong
# directory.
DISPATCHER_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(DISPATCHER_ROOT / "python"))
|
|
|
|
from ctypes_utils import (
|
|
setup_gemm_dispatcher,
|
|
KernelConfig,
|
|
get_build_dir,
|
|
)
|
|
|
|
|
|
class TestLibraryCaching(unittest.TestCase):
    """Test library caching functionality in setup_gemm_dispatcher.

    Every test drives ``setup_gemm_dispatcher`` with an fp16 / row-col-row
    ``compv4`` configuration targeting ``gfx950`` and inspects the library
    reported in the result (``result.lib.path``) and/or the files found in
    ``<build_dir>/examples``.

    NOTE(review): unittest runs the methods in name order, so a library built
    by an earlier test is still on disk for later ones; the tests are written
    to tolerate both a fresh build and an already-cached library.
    """

    # Name of the default library; a result may report it instead of a
    # config-specific library (the original tests treat that as valid).
    _DEFAULT_LIB_NAME = "libdispatcher_gemm_lib.so"

    # A first call slower than this is taken as evidence that a build ran.
    _MIN_BUILD_SECONDS = 5.0

    @classmethod
    def setUpClass(cls):
        """Set up test environment once for all tests."""
        cls.build_dir = get_build_dir()
        cls.examples_dir = cls.build_dir / "examples"

        # Clean up any previous test libraries
        cls._cleanup_test_libraries()

    @classmethod
    def _cleanup_test_libraries(cls):
        """Remove test library files left in the examples directory.

        Removal is best-effort: a file that cannot be deleted is reported and
        skipped instead of aborting the test run.
        """
        if not cls.examples_dir.exists():
            return

        for lib in cls.examples_dir.glob("libdispatcher_gemm_fp16_rcr_*_compv4_*.so"):
            try:
                lib.unlink()
            except OSError as exc:
                print(f" Warning: could not remove {lib.name}: {exc}")

    @staticmethod
    def _make_config(tile_k=64):
        """Return the KernelConfig shared by the tests.

        Args:
            tile_k: K tile dimension; the only parameter the tests vary.
        """
        return KernelConfig(
            dtype_a="fp16",
            layout_a="row",
            layout_b="col",
            layout_c="row",
            tile_m=128,
            tile_n=128,
            tile_k=tile_k,
            pipeline="compv4",
            gfx_arch="gfx950",
        )

    def test_01_unique_library_naming(self):
        """Test that library names include all distinguishing parameters."""
        config = self._make_config()

        result = setup_gemm_dispatcher(config, verbose=False, auto_rebuild=True)

        self.assertTrue(result.success, "setup_gemm_dispatcher should succeed")
        self.assertIsNotNone(result.lib, "Library should be loaded")

        lib_name = result.lib.path.name

        # Verify library name includes all parameters
        self.assertIn("fp16", lib_name, "Library name should include dtype")
        self.assertIn("rcr", lib_name, "Library name should include layout")
        self.assertIn("128x128x64", lib_name, "Library name should include tile dimensions")
        self.assertIn("2x2x1", lib_name, "Library name should include wave dimensions")
        self.assertIn("32x32x16", lib_name, "Library name should include warp dimensions")
        self.assertIn("compv4", lib_name, "Library name should include pipeline")
        self.assertIn("cshuffle", lib_name, "Library name should include epilogue")
        self.assertIn("intrawave", lib_name, "Library name should include scheduler")

        print(f"✓ Library name includes all parameters: {lib_name}")

    def test_02_library_build_and_cache(self):
        """Test that libraries are built correctly and then cached.

        The timing comparison is only asserted when the first call was slow
        enough to indicate a real build; otherwise the outcome is reported
        without asserting on timings.
        """
        config = self._make_config()

        expected_lib_name = (
            "libdispatcher_gemm_fp16_rcr_128x128x64_2x2x1_32x32x16"
            "_compv4_cshuffle_intrawave.so"
        )
        expected_lib_path = self.examples_dir / expected_lib_name

        # First call - should build library (unless it is already on disk).
        # perf_counter() is monotonic, unlike time.time().
        start = time.perf_counter()
        result1 = setup_gemm_dispatcher(config, verbose=False, auto_rebuild=True)
        time1 = time.perf_counter() - start

        self.assertTrue(result1.success, "First setup should succeed")
        self.assertIsNotNone(result1.lib, "Library should be loaded")

        # Check if library was created (might use default if config matches)
        lib_created = expected_lib_path.exists()
        if lib_created:
            print(f"✓ Library created: {expected_lib_name}")
        else:
            # Config might match default library, which is also valid
            print(f" Config matches default library: {result1.lib.path.name}")

        # Second call - should use cache if library was built
        start = time.perf_counter()
        result2 = setup_gemm_dispatcher(config, verbose=False, auto_rebuild=True)
        time2 = time.perf_counter() - start

        self.assertTrue(result2.success, "Second setup should succeed")

        # If library was created, second call should be much faster (cached)
        if lib_created and time1 > self._MIN_BUILD_SECONDS:
            self.assertLess(
                time2,
                time1 * 0.5,
                f"Cached load ({time2:.2f}s) should be much faster than build ({time1:.2f}s)",
            )
            # Guard against a zero-length second measurement.
            speedup = time1 / time2 if time2 > 0 else float("inf")
            print(f"✓ Cache reuse: {time2:.2f}s vs {time1:.2f}s ({speedup:.1f}x faster)")
        elif lib_created:
            # The custom library existed before the first call finished
            # quickly, so no build time is available to compare against.
            print(" Both calls fast (library was already cached)")
        else:
            print(" Both calls fast (using default library)")

    def test_03_different_configs_different_libraries(self):
        """Test that different configs create different library files."""
        # Two configs that differ only in tile_k.
        configs = [self._make_config(tile_k=64), self._make_config(tile_k=32)]

        results = [
            setup_gemm_dispatcher(
                config,
                registry_name=f"test_registry_{i}",
                verbose=False,
                auto_rebuild=True,
            )
            for i, config in enumerate(configs)
        ]

        # Check that all setups succeeded
        for i, result in enumerate(results):
            self.assertTrue(result.success, f"Setup {i+1} should succeed")

        # Check that different configs loaded different libraries (if both built custom libs)
        lib_names = [r.lib.path.name for r in results if r.lib]

        # Config-specific libraries carry the tile prefix in their name; the
        # default library name does not, so it is excluded by this filter.
        custom_libs = [
            name for name in lib_names
            if "libdispatcher_gemm_fp16_rcr_128x128" in name
        ]

        if len(custom_libs) >= 2:
            # Should have different tile dimensions in names
            self.assertNotEqual(
                custom_libs[0],
                custom_libs[1],
                "Different configs should create different libraries",
            )
            self.assertIn("128x128x64", custom_libs[0])
            self.assertIn("128x128x32", custom_libs[1])
            print("✓ Different configs created different libraries:")
            for lib in custom_libs:
                print(f" - {lib}")
        else:
            print(" Configs used default library (valid when configs match default)")

    def test_04_cache_message_verification(self):
        """Test that cache hit messages are logged correctly.

        The second, verbose call has its stdout captured. The captured text is
        checked for the cache-hit message; the other outcomes are reported
        but, as in the original test, not treated as failures.
        """
        import io
        from contextlib import redirect_stdout

        config = self._make_config()

        # First call
        result1 = setup_gemm_dispatcher(config, verbose=False, auto_rebuild=True)
        self.assertTrue(result1.success)

        # Second call - capture output to check for cache message
        captured = io.StringIO()
        with redirect_stdout(captured):
            result2 = setup_gemm_dispatcher(config, verbose=True, auto_rebuild=True)

        output = captured.getvalue()

        self.assertTrue(result2.success)
        self.assertIsNotNone(result2.lib, "Library should be loaded")

        # Check if cache was used (either message appears or default lib was used)
        if "Using cached library" in output:
            print("✓ Cache hit message logged correctly")
        elif self._DEFAULT_LIB_NAME in str(result2.lib.path):
            print(" Using default CMake library (no rebuild needed)")
        else:
            print(" Warning: Expected cache message not found (may have rebuilt)")

    def test_05_code_fix_verification(self):
        """Verify the code changes are in place.

        This is a textual check: it searches the source of ctypes_utils.py
        for the snippets introduced by the caching fix.
        """
        from ctypes_utils import get_dispatcher_root

        ctypes_utils_path = get_dispatcher_root() / "python" / "ctypes_utils.py"
        self.assertTrue(ctypes_utils_path.exists(), "ctypes_utils.py should exist")

        # Explicit encoding so the read does not depend on the locale.
        code = ctypes_utils_path.read_text(encoding="utf-8")

        # Check Fix #1: Complete library naming
        self.assertIn(
            "_{config.pipeline}_{config.epilogue}_{config.scheduler}",
            code,
            "Library naming should include pipeline, epilogue, and scheduler",
        )
        self.assertIn(
            "_{wave_str}_{warp_str}_",
            code,
            "Library naming should include wave and warp dimensions",
        )

        # Check Fix #2: Cache checking logic
        self.assertIn(
            "cached_lib_path.exists()",
            code,
            "Cache checking logic should be present",
        )
        self.assertIn(
            "Using cached library",
            code,
            "Cache hit message should be present",
        )

        print("✓ Code fixes verified:")
        print(" - Complete library naming (dtype, layout, tile, wave, warp, pipeline, epilogue, scheduler)")
        print(" - Cache checking logic present")
|
|
|
|
|
|
def run_tests(verbosity=2):
    """Run the TestLibraryCaching suite and report the outcome as an exit code.

    Args:
        verbosity: Verbosity level passed to ``unittest.TextTestRunner``.

    Returns:
        0 when every test passed, 1 otherwise.
    """
    suite = unittest.TestLoader().loadTestsFromTestCase(TestLibraryCaching)
    outcome = unittest.TextTestRunner(verbosity=verbosity).run(suite)
    if outcome.wasSuccessful():
        return 0
    return 1
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: print a banner, run the suite, print a summary and
    # exit with the suite's status (0 on success, 1 on failure).
    separator = "=" * 80

    print(separator)
    print(" Library Caching Unit Tests")
    print(separator)
    print()

    exit_code = run_tests(verbosity=2)

    print()
    print(separator)
    print(" ✓ ALL TESTS PASSED" if exit_code == 0 else " ✗ SOME TESTS FAILED")
    print(separator)

    sys.exit(exit_code)
|