mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
Adding dispatcher architecture (#3300)
* WIP POC of dispatcher
* Dispatcher python workflow setup.
* Dispatcher cleanup and updates. Further dispatcher cleanup and updates. Build fixes. Improvements and python to CK example. Improvements to readme.
* Fixes to python paths
* Cleaning up code
* Improving dispatcher support for different arch. Fixing typos.
* Fix formatting errors
* Cleaning up examples
* Improving code generation
* Improving and fixing C++ examples
* Adding conv functionality (fwd, bwd, bwdw) and examples.
* Fixes based on feedback.
* Further fixes based on feedback.
* Adding stress test for autogeneration and autocorrection, and fixing preshuffle bug.
* Another round of improvements based on feedback.
* Trimming out unnecessary code.
* Fixing the multi-D implementation.
* Using gpu verification for gemms and fixing convolutions tflops calculation.
* Fix counter usage issue and arch filtering per ops.
* Adding changelog and other fixes.
* Improve examples and resolve critical bugs.
* Reduce build time for python examples.
* Fixing minor bug.
* Fix compilation error.
* Improve installation instructions for dispatcher.
* Add docker based installation instructions for dispatcher.
* Fixing arch-based filtering to match tile engine.
* Remove dead code and fix arch filtering.
* Minor bugfix.
* Updates after rebase.
* Trimming code.
* Fix copyright headers.
* Consolidate examples, cut down code.
* Minor fixes.
* Improving python examples.
* Update readmes.
* Remove conv functionality.
* Cleanup following conv removal.
committed by GitHub
parent 44f481a45c · commit 9e049a32a1
299 lines added: dispatcher/examples/gemm/python/README.md (new file)
# GEMM Python Examples

CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations.

> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md)

## Quick Start
### Build Library

```bash
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build

cmake .. \
    -DCMAKE_PREFIX_PATH=/opt/rocm \
    -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
    -DBUILD_DISPATCHER_EXAMPLES=ON

# Build Python library (kernels generated automatically)
make dispatcher_gemm_lib -j$(nproc)
```
### Run Examples

```bash
cd /path/to/composable_kernel/dispatcher

python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py
```
## Examples

| Example | Description |
|---------|-------------|
| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support |
| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations |
| [03_benchmark.py](03_benchmark.py) | Performance benchmarking |
| [04_validation.py](04_validation.py) | CPU reference validation |
| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration |
| [06_json_export.py](06_json_export.py) | Registry JSON export |
| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing |
| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection |
| [09_multi_registry.py](09_multi_registry.py) | Multiple registries |
| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control |
| [11_json_import.py](11_json_import.py) | Import kernels from JSON |
## Example Details

### 01_basic_gemm.py - Basic GEMM

Demonstrates the Python API with multi-kernel support:

```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table

# Define multiple kernel configurations
kernels = [
    KernelConfig(
        tile_m=128, tile_n=128, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv3", scheduler="intrawave"
    ),
    KernelConfig(
        tile_m=256, tile_n=256, tile_k=32,
        wave_m=2, wave_n=2, wave_k=1,
        warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
        pipeline="compv4", scheduler="intrawave"
    ),
]

# Display configurations
print_kernel_config_table(kernels)

# Set up dispatcher with all kernels
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)

# Run GEMM
elapsed_ms = run_gemm(lib, M, N, K, ...)
```
### 02_batch_gemm.py - Batch GEMM

Batched matrix multiplication:
- Multiple independent GEMM operations
- Batch dimension handling
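The batching pattern amounts to `C[b] = A[b] @ B[b]` for each batch index `b`. A plain NumPy reference (illustrative only, independent of the dispatcher API) makes that concrete:

```python
import numpy as np

def batched_gemm_reference(A, B):
    """CPU reference for batched GEMM: C[b] = A[b] @ B[b] for every batch index b."""
    assert A.shape[0] == B.shape[0], "batch dimensions must match"
    # np.matmul broadcasts matrix multiplication over the leading batch dimension
    return np.matmul(A, B)

# Three independent 64x32 = 64x16 @ 16x32 GEMMs in one call
rng = np.random.default_rng(0)
A = rng.standard_normal((3, 64, 16)).astype(np.float32)
B = rng.standard_normal((3, 16, 32)).astype(np.float32)
C = batched_gemm_reference(A, B)
print(C.shape)  # (3, 64, 32)
```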
### 03_benchmark.py - Benchmarking

Performance measurement:
- GPU timing
- TFLOPS calculation
- Multiple iterations
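The TFLOPS arithmetic behind such benchmarks is simple enough to show directly (a sketch; the function name is illustrative, not the example's API):

```python
def gemm_tflops(M, N, K, elapsed_ms):
    """TFLOPS for one GEMM: 2*M*N*K floating-point ops (one multiply + one add per MAC)."""
    flops = 2.0 * M * N * K
    return flops / (elapsed_ms * 1e-3) / 1e12  # ops / seconds / 10^12

# A 4096x4096x4096 GEMM finishing in 10 ms sustains ~13.74 TFLOPS
print(f"{gemm_tflops(4096, 4096, 4096, 10.0):.2f}")  # 13.74
```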
### 04_validation.py - Validation

Correctness verification:
- NumPy reference implementation
- Tolerance-based validation
- Error reporting
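The three bullets can be sketched in a few lines of NumPy (illustrative; the actual example's helper names may differ). The reference is accumulated in float64 to keep the comparison fair for float32 inputs:

```python
import numpy as np

def validate_gemm(C_gpu, A, B, rtol=1e-3, atol=1e-3):
    """Compare a GPU result against a NumPy reference; report pass/fail and max error."""
    # Accumulate the reference in float64, then round to the output precision
    C_ref = (A.astype(np.float64) @ B.astype(np.float64)).astype(np.float32)
    max_err = float(np.max(np.abs(C_gpu - C_ref)))
    passed = bool(np.allclose(C_gpu, C_ref, rtol=rtol, atol=atol))
    return passed, max_err

rng = np.random.default_rng(42)
A = rng.standard_normal((128, 64)).astype(np.float32)
B = rng.standard_normal((64, 128)).astype(np.float32)
# A float32 matmul stands in for the GPU output here
passed, max_err = validate_gemm(A @ B, A, B)
print(f"{'PASS' if passed else 'FAIL'} max_err={max_err:.6e}")
```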
### 05_numpy_integration.py - NumPy Integration

Seamless NumPy integration:
- NumPy arrays to GPU buffers
- Results back to NumPy
- Automatic type conversion
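The core of the NumPy bridge is obtaining a raw host pointer from an array, which can then be handed to `hipMemcpy`-style calls. A minimal sketch of the general ctypes pattern (not the module's exact helpers):

```python
import ctypes
import numpy as np

# The array must be contiguous so its whole buffer is one addressable block
host = np.ascontiguousarray(np.arange(6, dtype=np.float32))

# Typed pointer to the array's buffer, as passed to C APIs
ptr = host.ctypes.data_as(ctypes.POINTER(ctypes.c_float))

# The pointer aliases the array's memory: writes through it are visible in NumPy
ptr[0] = 42.0
print(host[0])  # 42.0
```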
### 06_json_export.py - JSON Export

Registry serialization for tool integration:
- Export kernel configurations
- Machine-readable format
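The export amounts to serializing kernel configurations to JSON. A self-contained sketch using a reduced stand-in for `KernelConfig` (the real class has more fields):

```python
import json
from dataclasses import dataclass, asdict

@dataclass
class KernelCfg:  # hypothetical stand-in for KernelConfig, trimmed to a few fields
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str

kernels = [KernelCfg(128, 128, 32, "compv3"), KernelCfg(256, 256, 32, "compv4")]

# Machine-readable export: one dict per kernel under a top-level "kernels" key
payload = json.dumps({"kernels": [asdict(k) for k in kernels]}, indent=2)
print(payload)
```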
### 07_stress_test.py - Stress Testing

Comprehensive multi-kernel stress testing:

```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table

# Define 48 unique kernel configurations
kernels = [
    KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...),
    KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...),
    KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...),
    # ... many more configurations
]

# Test each kernel
for i, kernel in enumerate(kernels):
    lib, dispatcher, registry = setup_gemm_dispatcher([kernel])
    result = run_and_validate(lib, M, N, K, seed=42 + i)  # Different seed per kernel
    print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}")
```

**Features:**
- 48 unique kernel configurations
- Various tile sizes, pipelines, and schedulers
- Per-kernel validation with unique random seeds
- Performance reporting
### 08_heuristics.py - Heuristic Selection

Custom kernel selection based on problem characteristics:

```python
# Define kernel pools for different strategies
SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...]
LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...]
COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...]
MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...]

# Size-based heuristic
def size_based_heuristic(M, N, K):
    if M * N < 512 * 512:
        return SMALL_KERNELS
    else:
        return LARGE_KERNELS

# Strategy-based selection; every strategy takes (M, N, K) so they are
# interchangeable in the loop below, even if a strategy ignores the sizes
def compute_strategy(M, N, K):
    return COMPUTE_KERNELS  # Optimized for compute-bound problems

def memory_strategy(M, N, K):
    return MEMORY_KERNELS  # Optimized for memory-bound problems

# Test different strategies
for strategy in [size_based_heuristic, compute_strategy, memory_strategy]:
    kernels = strategy(M, N, K)
    lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
    elapsed_ms = run_gemm(lib, M, N, K, ...)
```

**Features:**
- 24 kernel configurations across 6 categories
- Size-based heuristic (small vs large)
- Optimization strategies (compute, memory, latency)
- Performance comparison across strategies
### 09_multi_registry.py - Multiple Registries

Separate registries for different workloads:
- Compute-optimized registry
- Latency-optimized registry
- Dynamic registry selection
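The idea can be sketched with registries modeled as named kernel pools plus a size-based selector (all names and thresholds below are illustrative, not the example's actual API):

```python
# Hypothetical registries: named pools of kernel identifiers
registries = {
    "compute": ["compv4_256x256", "compv4_128x256"],  # compute-optimized pool
    "latency": ["compv3_64x64", "compv3_64x128"],     # latency-optimized pool
}

def pick_registry(M, N, K):
    # Small problems are dominated by launch latency; large ones by throughput
    return "latency" if M * N * K < (1 << 24) else "compute"

# Dynamic registry selection per problem size
kernels = registries[pick_registry(64, 64, 64)]
print(kernels)  # ['compv3_64x64', 'compv3_64x128']
```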
### 10_advanced_benchmark.py - Advanced Benchmark

Full control over benchmark parameters:
- Warmup iterations
- Benchmark iterations
- Statistical analysis
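The warmup/iterations/statistics structure is a standard harness; a minimal host-side sketch (the real example times GPU kernels, and the helper name here is illustrative):

```python
import statistics
import time

def benchmark(fn, warmup=3, iters=10):
    """Time fn, discarding warmup runs; return (mean, stdev) in milliseconds."""
    for _ in range(warmup):
        fn()  # warmup runs excluded from statistics
    times_ms = []
    for _ in range(iters):
        t0 = time.perf_counter()
        fn()
        times_ms.append((time.perf_counter() - t0) * 1e3)
    return statistics.mean(times_ms), statistics.stdev(times_ms)

mean_ms, stdev_ms = benchmark(lambda: sum(range(10_000)))
print(f"{mean_ms:.3f} ms +/- {stdev_ms:.3f} ms")
```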
### 11_json_import.py - JSON Import

Import kernel configurations from JSON:
- External configuration files
- Dynamic kernel loading
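The import side mirrors the export in 06: parse the JSON and build configurations from it. A sketch (the field names below are illustrative; the real files use `KernelConfig`'s fields):

```python
import json

# An external configuration file might look like this
text = """
{"kernels": [
  {"tile_m": 128, "tile_n": 128, "tile_k": 32, "pipeline": "compv3"},
  {"tile_m": 256, "tile_n": 256, "tile_k": 32, "pipeline": "compv4"}
]}
"""

configs = json.loads(text)["kernels"]
# Each dict could then be splatted into a config object, e.g. KernelConfig(**cfg)
print(len(configs), configs[0]["pipeline"])  # 2 compv3
```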
## Utility Module: ctypes_utils.py

```python
from ctypes_utils import (
    KernelConfig,               # Single kernel configuration
    setup_gemm_dispatcher,      # Set up dispatcher with kernels
    print_kernel_config_table,  # Display kernel configurations
    Dispatcher,                 # High-level dispatcher
    Registry,                   # Kernel registry
    Validator,                  # Validation utilities
)
```
### KernelConfig

```python
config = KernelConfig(
    # Tile sizes
    tile_m=256, tile_n=256, tile_k=32,
    # Wave configuration
    wave_m=2, wave_n=2, wave_k=1,
    # Warp tile sizes
    warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
    # Pipeline and scheduler
    pipeline="compv4",      # "compv3" or "compv4"
    scheduler="intrawave",  # "intrawave" or "interwave"
    # Optional
    epilogue="default",
    padding=True,
    double_buffer=True,
)
```
### setup_gemm_dispatcher

```python
# Single kernel
lib, dispatcher, registry = setup_gemm_dispatcher(config)

# Multiple kernels
lib, dispatcher, registry = setup_gemm_dispatcher([config1, config2, ...])

# With auto-rebuild
lib, dispatcher, registry = setup_gemm_dispatcher(config, auto_rebuild=True)
```
### print_kernel_config_table

```python
kernels = [config1, config2, config3]
print_kernel_config_table(kernels)
# Output:
# +----+------------+-------+----------+--------+-----------+
# | #  | Tile       | Wave  | Warp     | Pipe   | Scheduler |
# +----+------------+-------+----------+--------+-----------+
# | 1  | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave |
# | 2  | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave |
# | 3  | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave |
# +----+------------+-------+----------+--------+-----------+
```
### GPU Memory Management

```python
import ctypes
import numpy as np

# Load the HIP runtime library
hip = ctypes.CDLL("libamdhip64.so")

# Allocate GPU memory (pass sizes as size_t to avoid truncation of large buffers)
gpu_ptr = ctypes.c_void_p()
hip.hipMalloc(ctypes.byref(gpu_ptr), ctypes.c_size_t(size_in_bytes))

# Copy to GPU (1 = hipMemcpyHostToDevice)
hip.hipMemcpy(gpu_ptr, ctypes.c_void_p(host_array.ctypes.data),
              ctypes.c_size_t(size_in_bytes), 1)

# Copy back (2 = hipMemcpyDeviceToHost)
hip.hipMemcpy(ctypes.c_void_p(host_array.ctypes.data), gpu_ptr,
              ctypes.c_size_t(size_in_bytes), 2)

# Free
hip.hipFree(gpu_ptr)
```
## Performance Testing

Test compilation performance with different kernel counts:

```bash
# Test with 10 kernels (~15s compile time)
python3 01_basic_gemm.py --num-kernels 10

# Test with 20 kernels (~25s compile time)
python3 01_basic_gemm.py --num-kernels 20

# Test with 48 kernels (~50s compile time)
python3 01_basic_gemm.py --num-kernels 48
```

Compilation time scales roughly linearly with kernel count.
## Related Documentation

- [C++ GEMM Examples](../cpp/README.md)
- [Python Conv Examples](../../conv/python/README.md)
- [Main Dispatcher README](../../../README.md)