Files
composable_kernel/dispatcher/examples/gemm/python/11_json_import.py
Vidyasagar Ananthan 9e049a32a1 Adding dispatcher architecture (#3300)
* WIP POC of dispatcher

* Dispatcher python workflow setup.

* Dispatcher cleanup and updates.

Further dispatcher cleanup and updates.

Build fixes

Improvements and python to CK example

Improvements to readme

* Fixes to python paths

* Cleaning up code

* Improving dispatcher support for different arch

Fixing typos

* Fix formatting errors

* Cleaning up examples

* Improving code generation

* Improving and fixing C++ examples

* Adding conv functionality (fwd,bwd,bwdw) and examples.

* Fixes based on feedback.

* Further fixes based on feedback.

* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.

* Another round of improvements based on feedback.

* Trimming out unnecessary code.

* Fixing the multi-D implementation.

* Using gpu verification for gemms and fixing convolutions tflops calculation.

* Fix counter usage issue and arch filtering per ops.

* Adding changelog and other fixes.

* Improve examples and resolve critical bugs.

* Reduce build time for python examples.

* Fixing minor bug.

* Fix compilation error.

* Improve installation instructions for dispatcher.

* Add Docker-based installation instructions for dispatcher.

* Fixing arch-based filtering to match tile engine.

* Remove dead code and fix arch filtering.

* Minor bugfix.

* Updates after rebase.

* Trimming code.

* Fix copyright headers.

* Consolidate examples, cut down code.

* Minor fixes.

* Improving python examples.

* Update readmes.

* Remove conv functionality.

* Cleanup following conv removal.
2026-01-22 09:34:33 -08:00

311 lines
11 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 11: JSON-based Kernel Configuration Import
Demonstrates loading kernel configurations from JSON files, similar to tile_engine.
This enables easy customization of kernel sets without modifying code.
Key Features:
- Load tile configs from JSON (compatible with tile_engine format)
- Generate kernel sets from configuration
- Use arch_filter validation on loaded configs
- Export to C++ DECL_KERNEL_SET format
Complexity: ★★★☆☆
Usage:
python3 11_json_import.py
python3 11_json_import.py --config my_kernels.json
python3 11_json_import.py --export-cpp
"""
import argparse
import json
import sys
import tempfile
from pathlib import Path
# Add codegen to path for kernel_config_loader
script_dir = Path(__file__).parent.resolve()
sys.path.insert(0, str(script_dir.parent.parent.parent / "codegen"))
sys.path.insert(0, str(script_dir.parent.parent.parent / "python"))
from kernel_config_loader import ( # noqa: E402
load_kernel_configs,
KernelConfig,
generate_cpp_kernel_set_declaration,
)
from ctypes_utils import ( # noqa: E402
KernelConfig as DispatcherKernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
validate_kernel_config,
)
# Embedded sample configuration, used when no --config file is given.
# Every tunable is expressed as {"values": [...]}, i.e. a list of candidate
# settings for that knob.
_SAMPLE_TILE_CONFIG = {
    "tile_m": {"values": [128, 256]},
    "tile_n": {"values": [128, 256]},
    "tile_k": {"values": [32]},
    "warp_m": {"values": [2]},
    "warp_n": {"values": [2]},
    "warp_k": {"values": [1]},
    "warp_tile_m": {"values": [32]},
    "warp_tile_n": {"values": [32]},
    "warp_tile_k": {"values": [16]},
}

_SAMPLE_TRAIT_CONFIG = {
    "pipeline": {"values": ["compv4"]},
    "scheduler": {"values": ["intrawave"]},
    "epilogue": {"values": ["cshuffle"]},
    "pad_m": {"values": [False]},
    "pad_n": {"values": [False]},
    "pad_k": {"values": [False]},
}

SAMPLE_JSON_CONFIG = {
    "_comment": "Sample kernel configuration for GEMM",
    "kernel_set_name": "inference_kernels",
    "datatype": {"a": "fp16", "b": "fp16", "c": "fp16", "acc": "fp32"},
    "layout": "rcr",
    "tile_config": _SAMPLE_TILE_CONFIG,
    "trait_config": _SAMPLE_TRAIT_CONFIG,
    "gpu_targets": ["gfx942"],
}
def print_section(title: str):
    """Write ``title`` to stdout framed by two 70-character ``=`` rules.

    A blank line is emitted before the top rule and after the bottom rule.
    """
    rule = "=" * 70
    # One call with a newline separator yields the same bytes as three prints.
    print("\n" + rule, f" {title}", rule + "\n", sep="\n")
def convert_to_dispatcher_config(
    config: KernelConfig, arch: str = "gfx942"
) -> DispatcherKernelConfig:
    """Build a dispatcher ``KernelConfig`` from a loader ``KernelConfig``.

    Args:
        config: Configuration produced by ``kernel_config_loader``.
        arch: GPU architecture string stored in the result's ``gfx_arch``.

    Returns:
        A ``DispatcherKernelConfig`` populated field by field from ``config``.
    """
    tile = config.tile
    trait = config.trait

    fields = {
        "dtype_a": config.dtype_a,
        "dtype_b": config.dtype_b,
        "dtype_c": config.dtype_c,
        "dtype_acc": config.dtype_acc,
        "tile_m": tile.tile_m,
        "tile_n": tile.tile_n,
        "tile_k": tile.tile_k,
        # The two classes name these fields differently: the loader's
        # ``warp_*`` fields feed the dispatcher's ``wave_*`` fields ...
        "wave_m": tile.warp_m,
        "wave_n": tile.warp_n,
        "wave_k": tile.warp_k,
        # ... and the loader's ``warp_tile_*`` fields feed ``warp_*``.
        "warp_m": tile.warp_tile_m,
        "warp_n": tile.warp_tile_n,
        "warp_k": tile.warp_tile_k,
        "pipeline": trait.pipeline,
        "scheduler": trait.scheduler,
        "epilogue": trait.epilogue,
        "pad_m": trait.pad_m,
        "pad_n": trait.pad_n,
        "pad_k": trait.pad_k,
        "gfx_arch": arch,
        "variant": config.variant,
    }
    return DispatcherKernelConfig(**fields)
def _build_arg_parser():
    """Create the command-line parser for this example."""
    parser = argparse.ArgumentParser(
        description="JSON Kernel Configuration Import Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # RawDescriptionHelpFormatter prints the epilog verbatim.
        epilog=(
            "\n"
            "Examples:\n"
            "python3 11_json_import.py # Use embedded sample config\n"
            "python3 11_json_import.py --config my.json # Load from file\n"
            "python3 11_json_import.py --export-cpp # Generate C++ declarations\n"
            "python3 11_json_import.py --validate # Validate configs against arch\n"
        ),
    )
    parser.add_argument(
        "--config",
        type=str,
        help="Path to JSON configuration file (uses embedded sample if not provided)",
    )
    parser.add_argument(
        "--export-cpp",
        action="store_true",
        help="Export kernel set as C++ DECL_KERNEL_SET",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate all configurations against arch filter",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target GPU architecture (default: gfx942)",
    )
    return parser


def _load_config_set(config_arg):
    """Load the kernel config set from ``config_arg`` or the embedded sample.

    Args:
        config_arg: Value of ``--config`` (a path string) or ``None``/empty to
            use ``SAMPLE_JSON_CONFIG``.

    Returns:
        The object returned by ``load_kernel_configs``, or ``None`` when
        ``config_arg`` does not name an existing file (an error is printed to
        stderr in that case).
    """
    if config_arg:
        config_path = Path(config_arg)
        if not config_path.is_file():
            print(f" ERROR: Config file not found: {config_path}", file=sys.stderr)
            return None
        print(f" Loading from: {config_path}")
        return load_kernel_configs(config_path)

    print(" Using embedded sample configuration")
    # load_kernel_configs takes a path, so the sample is written to disk first.
    # A private temporary directory avoids a fixed, predictable /tmp path and
    # is removed automatically on exit from the ``with`` block.
    # NOTE(review): assumes load_kernel_configs reads the file eagerly - confirm.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir) / "sample_gemm_config.json"
        with open(temp_path, "w", encoding="utf-8") as f:
            json.dump(SAMPLE_JSON_CONFIG, f, indent=2)
        return load_kernel_configs(temp_path)


def _print_config_overview(config_set):
    """Print the kernel set's name, data types, layout, targets and count."""
    print(f"\n Kernel Set Name: {config_set.name}")
    print(
        f" Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}"
    )
    print(f" Layout: {config_set.layout}")
    print(f" GPU Targets: {config_set.gpu_targets}")
    print(f" Total Configurations: {config_set.config_count()}")


def _print_config_details(config_set):
    """Print the tile and trait value lists held by ``config_set`` (Step 2)."""
    print("\nStep 2: Configuration Details")
    print("-" * 50)
    print("\n Tile Configurations:")
    print(f" tile_m: {config_set.tile_m_values}")
    print(f" tile_n: {config_set.tile_n_values}")
    print(f" tile_k: {config_set.tile_k_values}")
    print(
        f" warp (wave): {config_set.warp_m_values}x{config_set.warp_n_values}x{config_set.warp_k_values}"
    )
    print(
        f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}"
    )
    print("\n Trait Configurations:")
    print(f" pipeline: {config_set.pipeline_values}")
    print(f" scheduler: {config_set.scheduler_values}")
    print(f" epilogue: {config_set.epilogue_values}")
    print(
        f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}"
    )


def _validate_configs(configs, arch):
    """Validate each config for ``arch`` and print a summary (Step 4)."""
    print("\nStep 4: Architecture Validation")
    print("-" * 50)
    valid_count = 0
    invalid_count = 0
    for config in configs:
        result = validate_kernel_config(convert_to_dispatcher_config(config, arch))
        if result.is_valid:
            valid_count += 1
            continue
        invalid_count += 1
        # Only the first three failures are printed in detail.
        if invalid_count <= 3:
            print(f"\n ✗ Invalid: {config.kernel_name()}")
            for error in result.errors:
                print(f" Error: {error}")
    print("\n Validation Summary:")
    print(f" ✓ Valid: {valid_count}")
    print(f" ✗ Invalid: {invalid_count}")
    print(f" Total: {len(configs)}")


def _export_cpp(config_set):
    """Print the generated C++ kernel-set declaration (Step 5)."""
    print("\nStep 5: C++ Export")
    print("-" * 50)
    print("\n // Generated DECL_KERNEL_SET from JSON config:")
    print(" // " + "=" * 56)
    cpp_code = generate_cpp_kernel_set_declaration(config_set)
    for line in cpp_code.split("\n"):
        print(f" {line}")


def _run_dispatcher_demo(configs, arch):
    """Set up the GEMM dispatcher with the first config, if any (Step 6)."""
    print("\nStep 6: Dispatcher Integration Demo")
    print("-" * 50)
    if not configs:
        print("\n No configurations generated; skipping dispatcher demo")
        return

    first_config = configs[0]
    disp_config = convert_to_dispatcher_config(first_config, arch)
    tile = first_config.tile
    print(f"\n Using first config: {tile.tile_m}x{tile.tile_n}x{tile.tile_k}")
    setup = setup_gemm_dispatcher(
        disp_config, registry_name="json_import", verbose=False
    )
    if setup.success:
        header_name = setup.kernel_header.name if setup.kernel_header else "N/A"
        print(" ✓ Dispatcher setup successful")
        print(f" Kernel header: {header_name}")
    else:
        print(f" ⚠ Dispatcher setup: {setup.error}")
        print(" (This is expected if kernels aren't generated)")


def _print_summary():
    """Print the closing summary and usage hints."""
    print_section("Summary")
    print(" JSON configuration allows easy kernel set customization:")
    print(" - Define tile sizes and ranges")
    print(" - Specify trait combinations (pipeline, scheduler, etc.)")
    print(" - Target multiple GPU architectures")
    print(" - Export to C++ DECL_KERNEL_SET for static compilation")
    print()
    print(" JSON Format (tile_engine compatible):")
    print(' {"tile_config": {"tile_m": {"values": [128, 256]}, ...},')
    print(' "trait_config": {"pipeline": {"values": ["compv4"]}, ...}}')
    print()
    print(" Usage:")
    print(" config_set = load_kernel_configs('my_kernels.json')")
    print(" for config in config_set.generate_configs():")
    print(" # Use config for codegen or dispatcher setup")


def main():
    """Run the JSON kernel-configuration import example.

    Parses the command line, loads a kernel config set (from ``--config`` or
    the embedded sample), prints its contents and the first generated kernel
    names, optionally validates the configs (``--validate``) and exports a C++
    declaration (``--export-cpp``), then tries a dispatcher setup with the
    first config.

    Returns:
        0 on success, 1 when the ``--config`` path is not an existing file.
    """
    args = _build_arg_parser().parse_args()
    reset_for_example()
    try:
        print_section("Example 11: JSON Kernel Configuration Import")

        # Step 1: load configuration from JSON.
        print("Step 1: Load Kernel Configuration from JSON")
        print("-" * 50)
        config_set = _load_config_set(args.config)
        if config_set is None:
            return 1
        _print_config_overview(config_set)

        # Step 2: display configuration details.
        _print_config_details(config_set)

        # Step 3: generate and display (at most five) kernel names.
        print("\nStep 3: Generated Kernel Names")
        print("-" * 50)
        configs = list(config_set.generate_configs())
        for position, config in enumerate(configs[:5], start=1):
            print(f" {position}. {config.kernel_name()}")
        if len(configs) > 5:
            print(f" ... and {len(configs) - 5} more configurations")

        # Steps 4 and 5 are opt-in.
        if args.validate:
            _validate_configs(configs, args.arch)
        if args.export_cpp:
            _export_cpp(config_set)

        # Step 6: use the first config with the dispatcher (demo).
        _run_dispatcher_demo(configs, args.arch)

        _print_summary()
        return 0
    finally:
        # Runs on the early error return and on exceptions as well, so the
        # cleanup is never skipped once reset_for_example() has been called.
        cleanup_gemm()
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())