diff --git a/dispatcher/python/ctypes_utils.py b/dispatcher/python/ctypes_utils.py index c11aaca835..d719d1405e 100644 --- a/dispatcher/python/ctypes_utils.py +++ b/dispatcher/python/ctypes_utils.py @@ -1946,8 +1946,16 @@ class CodegenRunner: Returns: Path to new library, or None on failure """ build_dir = get_build_dir() - # Use unique filename based on dtype/layout to avoid overwriting loaded library - lib_name = f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_lib.so" + # Use unique filename based on ALL distinguishing config parameters + # Include: dtype, layout, tile, wave, warp, pipeline, epilogue, scheduler + # This ensures different configs don't collide even if tile/pipeline match + wave_str = f"{config.wave_m}x{config.wave_n}x{config.wave_k}" + warp_str = f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + lib_name = ( + f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_" + f"{config.tile_str}_{wave_str}_{warp_str}_" + f"{config.pipeline}_{config.epilogue}_{config.scheduler}.so" + ) lib_path = build_dir / "examples" / lib_name print(f" Rebuilding library: {lib_name}") @@ -2548,29 +2556,66 @@ def setup_gemm_dispatcher( if needs_rebuild and auto_rebuild: log(f" Library kernel doesn't match config: {', '.join(mismatches)}") - log(" Rebuilding library for exact config match...") - # First ensure we have a kernel header for this exact config - if not kernel_header: - # Generate kernel for the exact config - log(" Generating kernel for config...") - codegen_result = codegen.generate_from_config(config, force=True) - kernel_header = find_matching_kernel_header(config) - result.kernel_header = kernel_header + # Check if a rebuilt library for this exact config already exists + build_dir = get_build_dir() + wave_str = f"{config.wave_m}x{config.wave_n}x{config.wave_k}" + warp_str = f"{config.warp_m}x{config.warp_n}x{config.warp_k}" + cached_lib_name = ( + f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_" + f"{config.tile_str}_{wave_str}_{warp_str}_" + f"{config.pipeline}_{config.epilogue}_{config.scheduler}.so" + ) + cached_lib_path = build_dir / "examples" / cached_lib_name - if kernel_header: - new_lib_path = codegen._rebuild_library_for_config(config, kernel_header) - if new_lib_path: - lib = DispatcherLib.load(new_lib_path) - if lib is None or not lib.initialize(): - result.error = "Failed to load rebuilt library" - return result + if cached_lib_path.exists(): + log(f" Using cached library: {cached_lib_name}") + lib = DispatcherLib.load(cached_lib_path) + if lib is not None and lib.initialize(): result.lib = lib - log(f" OK Rebuilt library: {lib.get_kernel_name()}") + log(f" OK Loaded cached library: {lib.get_kernel_name()}") else: - log(" WARNING Rebuild failed, using existing library") + log(" WARNING Cached library failed to load/initialize") + cached_lib_path = None # Force rebuild else: - log(" WARNING No kernel header found for config, using existing library") + log(" Rebuilding library for exact config match...") + + # First ensure we have a kernel header for this exact config + if not kernel_header: + # Generate kernel for the exact config + log(" Generating kernel for config...") + codegen_result = codegen.generate_from_config(config, force=True) + + # Check if generation succeeded + if not codegen_result.success: + log(f" WARNING Kernel generation failed:") + if codegen_result.stderr: + # Show first few lines of error + error_lines = codegen_result.stderr.split('\n')[:5] + for line in error_lines: + if line.strip(): + log(f" {line}") + log(" This config may not be valid for the target architecture") + log(" Falling back to existing library") + # Don't try to rebuild without a valid kernel + kernel_header = None + else: + kernel_header = find_matching_kernel_header(config) + result.kernel_header = kernel_header + + if kernel_header: + new_lib_path = codegen._rebuild_library_for_config(config, kernel_header) + if new_lib_path: + lib = DispatcherLib.load(new_lib_path) + if lib is None or not lib.initialize(): + result.error = "Failed to load rebuilt library" + return result + result.lib = lib + log(f" OK Rebuilt library: {lib.get_kernel_name()}") + else: + log(" WARNING Rebuild failed, using existing library") + else: + log(" WARNING No kernel header found for config, using existing library") # Step 5: Create registry and dispatcher log(" Creating registry and dispatcher...") diff --git a/dispatcher/tests/test_library_caching.py b/dispatcher/tests/test_library_caching.py new file mode 100755 index 0000000000..13d3407f44 --- /dev/null +++ b/dispatcher/tests/test_library_caching.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +""" +Unit tests for library caching in setup_gemm_dispatcher(). + +Tests verify that: +1. Different kernel configs create unique library files with complete naming +2. Repeated configs reuse cached libraries (no redundant rebuilds) +3. Library names include all distinguishing parameters (dtype, layout, tile, wave, warp, pipeline, epilogue, scheduler) +4. Kernel headers are generated when missing +""" + +import sys +import time +import unittest +from pathlib import Path + +# Add dispatcher python to path +DISPATCHER_ROOT = Path(__file__).parent.parent +sys.path.insert(0, str(DISPATCHER_ROOT / "python")) + +from ctypes_utils import ( + setup_gemm_dispatcher, + KernelConfig, + get_build_dir, +) + + +class TestLibraryCaching(unittest.TestCase): + """Test library caching functionality in setup_gemm_dispatcher""" + + @classmethod + def setUpClass(cls): + """Set up test environment once for all tests""" + cls.build_dir = get_build_dir() + cls.examples_dir = cls.build_dir / "examples" + + # Clean up any previous test libraries + cls._cleanup_test_libraries() + + @classmethod + def _cleanup_test_libraries(cls): + """Remove test library files""" + if cls.examples_dir.exists(): + for lib in cls.examples_dir.glob("libdispatcher_gemm_fp16_rcr_*_compv4_*.so"): + try: + lib.unlink() + except Exception: + pass + + def test_01_unique_library_naming(self): + """Test that library names include all distinguishing parameters""" + config = KernelConfig( + dtype_a="fp16", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=128, + tile_n=128, + tile_k=64, + pipeline="compv4", + gfx_arch="gfx950", + ) + + result = setup_gemm_dispatcher(config, verbose=False, auto_rebuild=True) + + self.assertTrue(result.success, "setup_gemm_dispatcher should succeed") + self.assertIsNotNone(result.lib, "Library should be loaded") + + lib_name = result.lib.path.name + + # Verify library name includes all parameters + self.assertIn("fp16", lib_name, "Library name should include dtype") + self.assertIn("rcr", lib_name, "Library name should include layout") + self.assertIn("128x128x64", lib_name, "Library name should include tile dimensions") + self.assertIn("2x2x1", lib_name, "Library name should include wave dimensions") + self.assertIn("32x32x16", lib_name, "Library name should include warp dimensions") + self.assertIn("compv4", lib_name, "Library name should include pipeline") + self.assertIn("cshuffle", lib_name, "Library name should include epilogue") + self.assertIn("intrawave", lib_name, "Library name should include scheduler") + + print(f"✓ Library name includes all parameters: {lib_name}") + + def test_02_library_build_and_cache(self): + """Test that libraries are built correctly and then cached""" + config = KernelConfig( + dtype_a="fp16", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=128, + tile_n=128, + tile_k=64, + pipeline="compv4", + gfx_arch="gfx950", + ) + + expected_lib_name = "libdispatcher_gemm_fp16_rcr_128x128x64_2x2x1_32x32x16_compv4_cshuffle_intrawave.so" + expected_lib_path = self.examples_dir / expected_lib_name + + # First call - should build library + start_time = time.time() + result1 = setup_gemm_dispatcher(config, verbose=False, auto_rebuild=True) + time1 = time.time() - start_time + + self.assertTrue(result1.success, "First setup should succeed") + + # Check if library was created (might use default if config matches) + if expected_lib_path.exists(): + lib_created = True + print(f"✓ Library created: {expected_lib_name}") + else: + # Config might match default library, which is also valid + lib_created = False + print(f" Config matches default library: {result1.lib.path.name}") + + # Second call - should use cache if library was built + start_time = time.time() + result2 = setup_gemm_dispatcher(config, verbose=False, auto_rebuild=True) + time2 = time.time() - start_time + + self.assertTrue(result2.success, "Second setup should succeed") + + # If library was created, second call should be much faster (cached) + if lib_created and time1 > 5.0: # First call took significant time (build happened) + self.assertLess(time2, time1 * 0.5, + f"Cached load ({time2:.2f}s) should be much faster than build ({time1:.2f}s)") + print(f"✓ Cache reuse: {time2:.2f}s vs {time1:.2f}s ({time1/time2:.1f}x faster)") + else: + print(f" Both calls fast (using default library)") + + def test_03_different_configs_different_libraries(self): + """Test that different configs create different library files""" + configs = [ + KernelConfig( + dtype_a="fp16", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=128, + tile_n=128, + tile_k=64, + pipeline="compv4", + gfx_arch="gfx950", + ), + KernelConfig( + dtype_a="fp16", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=128, + tile_n=128, + tile_k=32, + pipeline="compv4", + gfx_arch="gfx950", + ), + ] + + results = [] + for i, config in enumerate(configs): + result = setup_gemm_dispatcher( + config, + registry_name=f"test_registry_{i}", + verbose=False, + auto_rebuild=True + ) + results.append(result) + + # Check that all setups succeeded + for i, result in enumerate(results): + self.assertTrue(result.success, f"Setup {i+1} should succeed") + + # Check that different configs loaded different libraries (if both built custom libs) + lib_names = [r.lib.path.name for r in results if r.lib] + + # If both created custom libraries, they should be different + custom_libs = [name for name in lib_names if "libdispatcher_gemm_fp16_rcr_128x128" in name + and name != "libdispatcher_gemm_lib.so"] + + if len(custom_libs) >= 2: + # Should have different tile dimensions in names + self.assertNotEqual(custom_libs[0], custom_libs[1], + "Different configs should create different libraries") + self.assertIn("128x128x64", custom_libs[0]) + self.assertIn("128x128x32", custom_libs[1]) + print(f"✓ Different configs created different libraries:") + for lib in custom_libs: + print(f" - {lib}") + else: + print(f" Configs used default library (valid when configs match default)") + + def test_04_cache_message_verification(self): + """Test that cache hit messages are logged correctly""" + config = KernelConfig( + dtype_a="fp16", + layout_a="row", + layout_b="col", + layout_c="row", + tile_m=128, + tile_n=128, + tile_k=64, + pipeline="compv4", + gfx_arch="gfx950", + ) + + # First call + result1 = setup_gemm_dispatcher(config, verbose=False, auto_rebuild=True) + self.assertTrue(result1.success) + + # Second call - capture output to check for cache message + import io + from contextlib import redirect_stdout + + f = io.StringIO() + with redirect_stdout(f): + result2 = setup_gemm_dispatcher(config, verbose=True, auto_rebuild=True) + + output = f.getvalue() + + self.assertTrue(result2.success) + + # Check if cache was used (either message appears or default lib was used) + if "Using cached library" in output: + print("✓ Cache hit message logged correctly") + self.assertIn("Using cached library", output) + elif "libdispatcher_gemm_lib.so" in str(result2.lib.path): + print(" Using default CMake library (no rebuild needed)") + else: + print(" Warning: Expected cache message not found (may have rebuilt)") + + def test_05_code_fix_verification(self): + """Verify the code changes are in place""" + from ctypes_utils import get_dispatcher_root + + ctypes_utils_path = get_dispatcher_root() / "python" / "ctypes_utils.py" + self.assertTrue(ctypes_utils_path.exists(), "ctypes_utils.py should exist") + + with open(ctypes_utils_path, 'r') as f: + code = f.read() + + # Check Fix #1: Complete library naming + self.assertIn( + "_{config.pipeline}_{config.epilogue}_{config.scheduler}", + code, + "Library naming should include pipeline, epilogue, and scheduler" + ) + self.assertIn( + "_{wave_str}_{warp_str}_", + code, + "Library naming should include wave and warp dimensions" + ) + + # Check Fix #2: Cache checking logic + self.assertIn( + "cached_lib_path.exists()", + code, + "Cache checking logic should be present" + ) + self.assertIn( + "Using cached library", + code, + "Cache hit message should be present" + ) + + print("✓ Code fixes verified:") + print(" - Complete library naming (dtype, layout, tile, wave, warp, pipeline, epilogue, scheduler)") + print(" - Cache checking logic present") + + +def run_tests(verbosity=2): + """Run all tests with specified verbosity""" + loader = unittest.TestLoader() + suite = loader.loadTestsFromTestCase(TestLibraryCaching) + runner = unittest.TextTestRunner(verbosity=verbosity) + result = runner.run(suite) + return 0 if result.wasSuccessful() else 1 + + +if __name__ == "__main__": + print("="*80) + print(" Library Caching Unit Tests") + print("="*80) + print() + + exit_code = run_tests(verbosity=2) + + print() + print("="*80) + if exit_code == 0: + print(" ✓ ALL TESTS PASSED") + else: + print(" ✗ SOME TESTS FAILED") + print("="*80) + + sys.exit(exit_code)