From 920acd2c12d31afcb79ef18f49690e6e1c90cdc1 Mon Sep 17 00:00:00 2001 From: Vidyasagar Ananthan Date: Thu, 9 Apr 2026 17:39:35 +0000 Subject: [PATCH] [rocm-libraries] ROCm/rocm-libraries#5168 (commit 8b5afcb) [CK] [CK_Tile] Add GroupConv to Kernel Dispatcher ## Motivation This PR adds CK Tile group convolution (forward, backward-data, backward-weight) support to the kernel dispatcher, matching and unifying with the existing dispatcher GEMM infrastructure in architecture and usability. The dispatcher provides a unified kernel dispatch system with both C++ and Python frontends, and until now only supported GEMM operations. This PR enables framework integrators to use the same declarative kernel workflow for convolutions as they do for GEMM: declare kernels, build a registry JIT, select kernels within the registry at runtime, and dispatch to GPU. Future PRs will include runtime kernel selection heuristics for autotuning of kernel parameters based on (problem, hardware arch). ## Technical Details Grouped convolution support has been added to the CK Tile Dispatcher with generated_conv_backend.hpp enabling dispatcher.run(in, wei, out, problem) for all 6 conv variants (fwd/bwdd/bwdw x 2D/3D), runtime heuristic kernel selection, and GroupedConvKernelKey with full ConvConfigBase fields. Python side adds parallel JIT via registry.build(max_workers) and heuristic registry.select(). Includes 7 C++ and 6 Python examples covering all directions with CPU reference validation, and shared infrastructure improvements (BaseRegistry CRTP, structured exceptions). As a sanity check, JIT compile times for a single kernel remains the same and for multiple kernels there is better parallelism: Kernels | 1 worker | 8 workers 1 | 7.7 s | 7.7 s 2 | 15.9 s | 8.2 s 4 | 33.4 s | 9.7 s 6 | 52.3 s | 10.2 s ## Test Plan 145 ephemeral unit tests have been added to test basic functionality. All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7 C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference validation for forward, backward-data, and backward-weight (2D) in both C++ and Python examples pass. ## Test Result 30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56), 53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002 for all directions (fp16 vs fp32 reference). ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests. --- dispatcher/README.md | 141 +- dispatcher/bindings/README.md | 24 +- dispatcher/bindings/ctypes/CMakeLists.txt | 4 +- .../bindings/ctypes/conv_bwdw_ctypes_lib.cpp | 4 +- .../bindings/ctypes/conv_ctypes_lib.cpp | 643 +++--- dispatcher/codegen/ADDING_NEW_GPU.md | 16 +- dispatcher/codegen/README.md | 47 +- dispatcher/codegen/codegen_common.py | 350 ++++ .../generate_dispatcher_registration.py | 8 +- .../codegen/generate_kernel_wrappers.py | 8 +- dispatcher/codegen/kernel_config_loader.py | 32 +- dispatcher/codegen/unified_gemm_codegen.py | 236 +-- .../codegen/unified_grouped_conv_codegen.py | 1757 ++++++++++++++++ dispatcher/examples/CMakeLists.txt | 70 +- dispatcher/examples/README.md | 45 +- .../examples/gemm/cpp/02_multi_size.cpp | 20 +- .../examples/gemm/cpp/07_gfx950_minimal.cpp | 191 ++ dispatcher/examples/gemm/cpp/README.md | 18 +- .../examples/gemm/python/01_basic_gemm.py | 291 +-- .../examples/gemm/python/02_batch_gemm.py | 35 +- .../examples/gemm/python/03_benchmark.py | 37 +- .../examples/gemm/python/04_validation.py | 34 +- .../gemm/python/05_numpy_integration.py | 6 +- .../examples/gemm/python/06_json_export.py | 6 +- .../examples/gemm/python/07_stress_test.py | 6 +- .../examples/gemm/python/08_heuristics.py | 6 +- .../examples/gemm/python/09_multi_registry.py | 6 +- .../gemm/python/10_advanced_benchmark.py | 7 +- .../examples/gemm/python/11_json_import.py | 16 +- dispatcher/examples/gemm/python/README.md | 2 +- .../cpp/01_basic_grouped_conv.cpp | 203 ++ .../grouped_conv/cpp/02_all_directions.cpp | 216 ++ .../cpp/03_benchmark_validation.cpp | 263 +++ .../grouped_conv/cpp/04_registry_json.cpp | 154 ++ .../examples/grouped_conv/cpp/05_bwd_data.cpp | 183 ++ .../grouped_conv/cpp/06_bwd_weight.cpp | 188 ++ .../cpp/07_multi_tile_benchmark.cpp | 226 +++ .../python/01_basic_grouped_conv.py | 271 +++ .../grouped_conv/python/02_forward.py | 222 ++ .../grouped_conv/python/03_bwd_data.py | 214 ++ .../grouped_conv/python/04_bwd_weight.py | 224 ++ .../grouped_conv/python/05_benchmark.py | 318 +++ .../grouped_conv/python/06_registry_json.py | 274 +++ dispatcher/include/ck_tile/dispatcher.hpp | 20 +- .../include/ck_tile/dispatcher/README.md | 96 +- .../backends/generated_conv_backend.hpp | 152 ++ .../ck_tile/dispatcher/base_registry.hpp | 199 ++ .../include/ck_tile/dispatcher/dispatcher.hpp | 22 +- .../ck_tile/dispatcher/dispatcher_error.hpp | 28 + .../ck_tile/dispatcher/dispatcher_log.hpp | 55 + .../dispatcher/grouped_conv_config.hpp | 588 ++++++ .../dispatcher/grouped_conv_kernel_decl.hpp | 537 +++++ .../dispatcher/grouped_conv_problem.hpp | 255 +++ .../dispatcher/grouped_conv_registry.hpp | 614 ++++++ .../ck_tile/dispatcher/grouped_conv_utils.hpp | 324 +++ .../include/ck_tile/dispatcher/problem.hpp | 8 +- .../include/ck_tile/dispatcher/registry.hpp | 105 +- .../include/ck_tile/dispatcher_conv.hpp | 18 + .../include/ck_tile/dispatcher_gemm.hpp | 22 + dispatcher/python/CMakeLists.txt | 2 +- dispatcher/python/README.md | 48 +- dispatcher/python/ctypes_utils.py | 715 ++++++- dispatcher/python/dispatcher_common.py | 372 ++++ dispatcher/python/grouped_conv_utils.py | 1806 +++++++++++++++++ dispatcher/scripts/compile_gemm_examples.py | 87 +- .../scripts/compile_grouped_conv_examples.py | 882 ++++++++ dispatcher/scripts/example_kernel_builder.py | 396 ++-- .../scripts/generate_conv_dispatch_header.py | 107 + dispatcher/scripts/parallel_kernel_builder.py | 2 +- dispatcher/scripts/stress_test_autocorrect.py | 10 +- dispatcher/src/dispatcher.cpp | 13 +- dispatcher/src/registry.cpp | 181 +- dispatcher/tests/CMakeLists.txt | 4 + dispatcher/tests/test_autocorrect.py | 8 +- dispatcher/tests/test_codegen_common.py | 244 +++ dispatcher/tests/test_dispatcher_common.py | 243 +++ dispatcher/tests/test_examples_integration.py | 175 +- dispatcher/tests/test_grouped_conv_codegen.py | 589 ++++++ dispatcher/tests/test_grouped_conv_config.cpp | 112 + .../tests/test_grouped_conv_kernel_decl.cpp | 141 ++ .../tests/test_grouped_conv_problem.cpp | 245 +++ .../tests/test_grouped_conv_registry.cpp | 230 +++ dispatcher/tests/test_grouped_conv_utils.py | 349 ++++ dispatcher/tests/test_problem_extended.cpp | 8 +- .../tests/test_real_kernel_multi_size.cpp | 2 +- .../tests/test_real_kernel_performance.cpp | 2 +- 86 files changed, 15538 insertions(+), 1500 deletions(-) create mode 100644 dispatcher/codegen/codegen_common.py create mode 100644 dispatcher/codegen/unified_grouped_conv_codegen.py create mode 100644 dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp create mode 100644 dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp create mode 100644 dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp create mode 100644 dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp create mode 100644 dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp create mode 100644 dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp create mode 100644 dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp create mode 100644 dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp create mode 100644 dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py create mode 100644 dispatcher/examples/grouped_conv/python/02_forward.py create mode 100644 dispatcher/examples/grouped_conv/python/03_bwd_data.py create mode 100644 dispatcher/examples/grouped_conv/python/04_bwd_weight.py create mode 100644 dispatcher/examples/grouped_conv/python/05_benchmark.py create mode 100644 dispatcher/examples/grouped_conv/python/06_registry_json.py create mode 100644 dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/base_registry.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher_conv.hpp create mode 100644 dispatcher/include/ck_tile/dispatcher_gemm.hpp create mode 100644 dispatcher/python/dispatcher_common.py create mode 100644 dispatcher/python/grouped_conv_utils.py create mode 100644 dispatcher/scripts/compile_grouped_conv_examples.py create mode 100644 dispatcher/scripts/generate_conv_dispatch_header.py create mode 100644 dispatcher/tests/test_codegen_common.py create mode 100644 dispatcher/tests/test_dispatcher_common.py create mode 100644 dispatcher/tests/test_grouped_conv_codegen.py create mode 100644 dispatcher/tests/test_grouped_conv_config.cpp create mode 100644 dispatcher/tests/test_grouped_conv_kernel_decl.cpp create mode 100644 dispatcher/tests/test_grouped_conv_problem.cpp create mode 100644 dispatcher/tests/test_grouped_conv_registry.cpp create mode 100644 dispatcher/tests/test_grouped_conv_utils.py diff --git a/dispatcher/README.md b/dispatcher/README.md index 1395285d60..dc864f7c62 100644 --- a/dispatcher/README.md +++ b/dispatcher/README.md @@ -1,6 +1,6 @@ # CK Tile Dispatcher -A unified kernel dispatch system for AMD GPUs with C++ and Python frontends. +A unified kernel dispatch system for AMD GPUs with C++ and Python frontends, supporting GEMM and Grouped Convolution operations. **Validated Platform:** AMD Instinct MI300 series (gfx942) @@ -342,8 +342,8 @@ ls examples/libdispatcher_gemm_lib.so | `CMAKE_PREFIX_PATH` | - | ROCm installation path | | `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler | -⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower. -⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories). +WARNING: **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower. +WARNING: **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories). --- @@ -363,6 +363,15 @@ cd build/examples ./gemm_04_heuristics # Heuristic kernel selection ./gemm_05_json_export # Registry JSON export ./gemm_06_multi_registry # Multiple registries + +# Grouped Convolution Examples +./grouped_conv_01_basic # Declaration patterns + GPU execution +./grouped_conv_02_all_dirs # Forward/BwdData/BwdWeight with GPU +./grouped_conv_03_bench_val # Benchmark + CPU reference validation +./grouped_conv_04_registry_json # Heuristic selection + JSON export +./grouped_conv_05_bwd_data # Backward data + CPU validation +./grouped_conv_06_bwd_weight # Backward weight + CPU validation +./grouped_conv_07_benchmark # Multi-tile ResNet benchmark ``` ### Python Examples @@ -375,8 +384,16 @@ cd /path/to/composable_kernel/dispatcher # GEMM Examples python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM python3 examples/gemm/python/04_validation.py # CPU reference validation -python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels) +python3 examples/gemm/python/07_stress_test.py # Stress test python3 examples/gemm/python/08_heuristics.py # Heuristic selection + +# Grouped Convolution Examples +python3 examples/grouped_conv/python/01_basic_grouped_conv.py # Config patterns + registry + GPU +python3 examples/grouped_conv/python/02_forward.py # Forward 2D/3D + CPU ref +python3 examples/grouped_conv/python/03_bwd_data.py # Backward data + CPU ref +python3 examples/grouped_conv/python/04_bwd_weight.py # Backward weight + CPU ref +python3 examples/grouped_conv/python/05_benchmark.py # Multi-problem benchmark +python3 examples/grouped_conv/python/06_registry_json.py # Heuristic selection + JSON ``` ### Example Output @@ -647,7 +664,7 @@ lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so") ### Data Flow ``` -KernelConfig → Registry → Dispatcher → GPU Execution +KernelConfig -> Registry -> Dispatcher -> GPU Execution ``` 1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts) @@ -843,31 +860,49 @@ make -j$(nproc) ``` dispatcher/ -├── README.md # This file -├── CMakeLists.txt # Build configuration -│ -├── include/ck_tile/dispatcher/ # C++ headers -│ ├── dispatcher.hpp # GEMM dispatcher -│ ├── registry.hpp # Kernel registry -│ └── kernel_key.hpp # Kernel configuration -│ -├── src/ # C++ implementation -│ -├── codegen/ # Kernel generation -│ ├── unified_gemm_codegen.py # GEMM kernel generator -│ └── arch_specs.json # GPU specifications -│ -├── bindings/ctypes/ # Python ctypes interface -│ └── gemm_ctypes_lib.cpp # GEMM Python library -│ -├── examples/ # Examples -│ └── gemm/ -│ ├── cpp/ # C++ GEMM examples (01-06) -│ └── python/ # Python GEMM examples (01-11) -│ -├── scripts/ # Build scripts -│ -└── tests/ # Unit tests +|---- README.md # This file +|---- CMakeLists.txt # Build configuration +| +|---- include/ck_tile/dispatcher/ # C++ headers +| |---- dispatcher.hpp # Main dispatcher include +| |---- registry.hpp # GEMM kernel registry +| |---- kernel_key.hpp # Kernel configuration +| |---- grouped_conv_config.hpp # Grouped conv configuration +| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder) +| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations +| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe) +| +---- grouped_conv_utils.hpp # Grouped conv utilities +| +|---- src/ # C++ implementation +| +|---- codegen/ # Kernel generation +| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings +| |---- unified_gemm_codegen.py # GEMM kernel generator +| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator +| +---- arch_specs.json # GPU specifications +| +|---- python/ # Python utilities +| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output +| |---- ctypes_utils.py # GEMM ctypes utilities +| +---- grouped_conv_utils.py # Grouped conv utilities +| +|---- scripts/ # Build scripts +| |---- compile_gemm_examples.py # GEMM build script +| +---- compile_grouped_conv_examples.py # Grouped conv build script +| +|---- bindings/ctypes/ # Python ctypes interface +| |---- gemm_ctypes_lib.cpp # GEMM Python library +| +---- conv_ctypes_lib.cpp # Grouped conv Python library +| +|---- examples/ # Examples +| |---- gemm/ +| | |---- cpp/ # C++ GEMM examples (01-07) +| | +---- python/ # Python GEMM examples (01-11) +| +---- grouped_conv/ +| |---- cpp/ # C++ Grouped Conv examples (01-07) +| +---- python/ # Python Grouped Conv examples (01-06) +| ++---- tests/ # Unit tests (C++ and Python) ``` --- @@ -879,17 +914,49 @@ dispatcher/ | GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) | | GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) | | Codegen | [codegen/README.md](codegen/README.md) | +| Python Utils | [python/README.md](python/README.md) | +| C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) | --- -## Archived Content +## Grouped Convolution Support -Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`: -- `examples/conv/cpp/` - 11 C++ convolution examples -- `examples/conv/python/` - 14 Python convolution examples -- `codegen/unified_conv_codegen.py` - Conv kernel generator -- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers -- `python/conv_utils.py` - Conv Python utilities +Grouped convolution is fully supported alongside GEMM, with shared infrastructure to eliminate duplication. + +### Python + +```bash +# Generate grouped conv kernels +python3 codegen/unified_grouped_conv_codegen.py \ + --output-dir build/generated_kernels \ + --datatype fp16 --variant forward --ndim-spatial 2 + +# Build grouped conv examples +python3 scripts/compile_grouped_conv_examples.py examples/grouped_conv/cpp/01_basic_grouped_conv.cpp +``` + +### Key Files + +| Component | File | +|-----------|------| +| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` | +| Python Codegen | `codegen/unified_grouped_conv_codegen.py` | +| Python Utils | `python/grouped_conv_utils.py` | +| Build Script | `scripts/compile_grouped_conv_examples.py` | +| Shared Codegen | `codegen/codegen_common.py` | +| Shared Utils | `python/dispatcher_common.py` | + +### Variants + +- **Forward** (`grouped_conv_fwd`) - Standard grouped convolution +- **Backward Data** (`grouped_conv_bwd_data`) - Gradient w.r.t. input +- **Backward Weight** (`grouped_conv_bwd_weight`) - Gradient w.r.t. weights + +### Shared Infrastructure + +GEMM and grouped convolution share common code to avoid duplication: +- `codegen/codegen_common.py` - TileConfig, TraitConfigBase, type mappings, parallel generation, arch-aware expansion +- `python/dispatcher_common.py` - Path helpers, validation, auto-correction, Colors, phased output --- diff --git a/dispatcher/bindings/README.md b/dispatcher/bindings/README.md index 7cda21f6ec..04029d32a9 100644 --- a/dispatcher/bindings/README.md +++ b/dispatcher/bindings/README.md @@ -6,13 +6,13 @@ This directory contains language bindings for the CK Tile Dispatcher. ``` bindings/ -├── ctypes/ # Python ctypes bindings (C API) -│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API -│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data) -│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API -│ ├── gpu_helper.cpp # CLI helper for Python -│ └── CMakeLists.txt -└── README.md +|---- ctypes/ # Python ctypes bindings (C API) +| |---- gemm_ctypes_lib.cpp # GEMM dispatcher C API +| |---- conv_ctypes_lib.cpp # Grouped conv dispatcher C API (fwd + bwd_data) +| |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library) +| |---- gpu_helper.cpp # CLI helper for Python +| +---- CMakeLists.txt ++---- README.md ``` ## ctypes Bindings @@ -65,7 +65,7 @@ lib.dispatcher_cleanup() | `dispatcher_export_registry_json()` | Export registry as JSON | | `dispatcher_cleanup()` | Release resources | -### Convolution API +### Grouped Convolution API | Function | Description | |----------|-------------| @@ -105,5 +105,11 @@ Output is JSON for easy parsing: See the examples that use these bindings: - **GEMM**: `dispatcher/examples/gemm/python/` -- **Conv**: `dispatcher/examples/conv/python/` + +### Grouped Convolution + +Grouped convolution C++ headers and Python utilities are in: +- **C++ Headers**: `dispatcher/include/ck_tile/dispatcher/grouped_conv_*.hpp` +- **Python Utils**: `dispatcher/python/grouped_conv_utils.py` +- **Build Script**: `dispatcher/scripts/compile_grouped_conv_examples.py` diff --git a/dispatcher/bindings/ctypes/CMakeLists.txt b/dispatcher/bindings/ctypes/CMakeLists.txt index 804e5e9bd7..18314017f2 100644 --- a/dispatcher/bindings/ctypes/CMakeLists.txt +++ b/dispatcher/bindings/ctypes/CMakeLists.txt @@ -78,7 +78,7 @@ endif() # Look for forward kernels file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp") # Look for backward data kernels -file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp") +file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwd_data_*.hpp") # Fallback: any conv kernel (for backwards compatibility) file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp") @@ -112,7 +112,7 @@ endif() # Add backward data kernel if available if(CONV_BWDD_KERNEL_HEADERS) list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER) - message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}") + message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWD_DATA_KERNEL_HEADER}") target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER}) target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE) endif() diff --git a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp index 09e058f80f..96b4aa3462 100644 --- a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp +++ b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp @@ -53,6 +53,7 @@ struct ConvBwdwProblemC int stride_d, stride_h, stride_w; int pad_d, pad_h, pad_w; int dilation_d, dilation_h, dilation_w; + int split_k; }; // ============================================================================= @@ -108,8 +109,7 @@ static float run_bwd_weight_impl(const void* input_ptr, grad_weight_ptr, // wei_ptr = grad_weight (output) {}, // ds_ptr grad_output_ptr, // out_ptr = grad_output - 1 // k_batch - ); + (prob->split_k > 1) ? prob->split_k : 1); ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; diff --git a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp index d3c64621a7..002219c82e 100644 --- a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp +++ b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp @@ -1,128 +1,46 @@ // Copyright (c) Advanced Micro Devices, Inc., or its affiliates. // SPDX-License-Identifier: MIT - -/** - * Convolution Dispatcher ctypes Library - * - * Provides C API for Python ctypes integration. - * Supports forward convolution. Backward operations require additional headers. - * - * REQUIRED: Forward kernel header must be force-included via -include flag. - * OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE - * - * Usage from Python: - * lib = ctypes.CDLL("libdispatcher_conv.so") - * lib.conv_dispatcher_init() - * lib.conv_dispatcher_run(...) - */ +// +// Multi-kernel grouped convolution dispatcher for Python ctypes. +// +// Supports: forward / backward-data / backward-weight x 2D / 3D +// +// The dispatch header (conv_python_dispatch.hpp) is force-included via +// -include and brings in ALL compiled kernels with these aliases: +// +// 2D launchers (from include_all headers): +// SelectedConvKernelLauncher (forward 2D) +// SelectedConvBwdDataLauncher (backward-data 2D) +// SelectedConvBwdWeightLauncher (backward-weight 2D) +// +// 3D launchers (from dispatch header): +// ConvFwd3dLauncher (forward 3D) +// ConvBwdData3dLauncher (backward-data 3D) +// ConvBwdWeight3dLauncher (backward-weight 3D) +// +// Usage from Python: +// lib = ctypes.CDLL("libdispatcher_conv_lib.so") +// lib.conv_dispatcher_init() +// lib.conv_dispatcher_run(input, weight, output, &problem, stream) #include -#include -#include +#include #include -#include "ck_tile/dispatcher/conv_utils.hpp" #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" -using namespace ck_tile::dispatcher; - -// Global state (using shared_ptr for safe memory management) -static std::shared_ptr g_registry = nullptr; -static std::shared_ptr g_dispatcher = nullptr; -static std::vector g_kernels; - extern "C" { -// ============================================================================= -// Initialization -// ============================================================================= - -int conv_dispatcher_init() +// ========================================================================= +// Problem definition (matches Python ctypes struct exactly) +// ========================================================================= +enum ConvDirection { - if(g_registry) - return 0; // Already initialized - - g_registry = std::make_shared(); - g_dispatcher = std::make_shared(g_registry.get()); - - // Register kernel configurations using simple ConvKernelSet - // (actual kernel launch uses the force-included SelectedConvKernelLauncher) - using namespace ck_tile::dispatcher::conv_decl; - - // Forward kernels (required - must be force-included) - // Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb - ConvKernelSet fwd_set; - fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), - ConvAlgorithm() - .tile(128, 128, 64) // tile_m x tile_n x tile_k - .wave(2, 2, 1) - .warp(32, 32, 16) - .pipeline("compv4") - .scheduler("intrawave"), - "gfx942"); - g_registry->register_set(fwd_set, ConvRegistry::Priority::High); - -#ifdef CONV_BWD_DATA_AVAILABLE - // Backward data kernels - // Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16 - ConvKernelSet bwd_data_set; - bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), - ConvAlgorithm() - .tile(128, 128, 64) // tile_m x tile_n x tile_k - .wave(2, 2, 1) - .warp(32, 32, 16) - .pipeline("compv3") - .scheduler("intrawave"), - "gfx942"); - g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High); -#endif - - return 0; -} - -int conv_dispatcher_cleanup() -{ - // shared_ptr automatically handles cleanup when reset - g_dispatcher.reset(); - g_registry.reset(); - g_kernels.clear(); - return 0; -} - -// ============================================================================= -// Registry Management -// ============================================================================= - -int conv_dispatcher_get_kernel_count() -{ - if(!g_registry) - return 0; - return static_cast(g_registry->size()); -} - -int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size) -{ - if(index < 0 || !buffer || buffer_size <= 0) - return -1; - - if(!g_registry) - return -1; - - // Use registry to get kernel names (they are registered with full names) - const auto& kernels = g_registry->all_kernels(); - if(static_cast(index) >= kernels.size()) - return -1; - - const auto* kernel = kernels[index]; - std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1); - buffer[buffer_size - 1] = '\0'; - return 0; -} - -// ============================================================================= -// Problem Definition -// ============================================================================= + CONV_FORWARD = 0, + CONV_BWD_DATA = 1, + CONV_BWD_WEIGHT = 2 +}; struct ConvProblemC { @@ -132,267 +50,33 @@ struct ConvProblemC int stride_d, stride_h, stride_w; int pad_d, pad_h, pad_w; int dilation_d, dilation_h, dilation_w; - int direction; // 0=forward, 1=bwd_data, 2=bwd_weight + int direction; + int split_k; }; -// ============================================================================= -// Kernel Selection -// ============================================================================= +// ========================================================================= +// Initialization / lifecycle +// ========================================================================= +int conv_dispatcher_init() { return 0; } +int conv_dispatcher_cleanup() { return 0; } -int conv_dispatcher_is_supported(const ConvProblemC* prob) -{ - if(!g_registry || !prob) - return 0; - - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - const auto* kernel = g_dispatcher->select(problem); - return kernel ? 1 : 0; -} - -int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size) -{ - if(!g_registry || !prob || !kernel_name || buffer_size <= 0) - return -1; - - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - const auto* kernel = g_dispatcher->select(problem); - if(!kernel) - return -1; - - std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1); - kernel_name[buffer_size - 1] = '\0'; - - return 0; -} - -// ============================================================================= -// Convolution Execution -// ============================================================================= - -// Helper to build ConvParam -static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob) -{ - // Determine if this is 2D or 3D convolution - const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); - - if(is_3d) - { - // 3D convolution: use all spatial dimensions - return ck_tile::conv::ConvParam{3, - prob->G, - prob->N, - prob->K, - prob->C, - {prob->filter_z, prob->filter_y, prob->filter_x}, - {prob->input_d, prob->input_h, prob->input_w}, - {prob->stride_d, prob->stride_h, prob->stride_w}, - {prob->dilation_d, prob->dilation_h, prob->dilation_w}, - {prob->pad_d, prob->pad_h, prob->pad_w}, - {prob->pad_d, prob->pad_h, prob->pad_w}}; - } - else - { - // 2D convolution: only use H, W dimensions - return ck_tile::conv::ConvParam{2, - prob->G, - prob->N, - prob->K, - prob->C, - {prob->filter_y, prob->filter_x}, - {prob->input_h, prob->input_w}, - {prob->stride_h, prob->stride_w}, - {prob->dilation_h, prob->dilation_w}, - {prob->pad_h, prob->pad_w}, - {prob->pad_h, prob->pad_w}}; - } -} - -// Forward convolution (required - kernel header must be force-included) -static float run_forward(const void* input_ptr, - const void* weight_ptr, - void* output_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - // SelectedConvKernelLauncher is defined in the force-included forward kernel header - return SelectedConvKernelLauncher::launch(args, stream_cfg); -} - -#ifdef CONV_BWD_DATA_AVAILABLE -// Backward data convolution (optional) -// Computes: grad_input = conv_bwd_data(weight, grad_output) -// -// Parameters: -// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT) -// weight_ptr: W - frozen weights (const, read-only INPUT) -// grad_input_ptr: dX - gradient for input (writable, OUTPUT) -static float run_bwd_data(const void* grad_output_ptr, - const void* weight_ptr, - void* grad_input_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - // CK Tile API uses tensor POSITION names (from forward pass), not data flow: - // in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data) - // wei_ptr = weight tensor = weight_ptr (W, const) - // out_ptr = output tensor position = grad_output_ptr (dY, INPUT to bwd_data) - ck_tile::GroupedConvBwdDataHostArgs args( - conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - return SelectedConvBwdDataLauncher::launch(args, stream_cfg); -} -#endif - -#ifdef CONV_BWD_WEIGHT_AVAILABLE -// Backward weight convolution (optional) -// Parameters: -// input_ptr: original forward input X (const, read-only) -// grad_output_ptr: gradient from next layer dY (const, read-only) -// grad_weight_ptr: gradient of weights dW (writable, OUTPUT) -static float run_bwd_weight(const void* input_ptr, - const void* grad_output_ptr, - void* grad_weight_ptr, - const ConvProblemC* prob, - void* stream) -{ - auto conv_param = build_conv_param(prob); - - // GroupedConvBwdWeightHostArgs constructor order: - // (param, in=X, wei=dW (output), ds, out=dY (input), k_batch) - // Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output) - ck_tile::GroupedConvBwdWeightHostArgs args( - conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1); - - ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10}; - - return SelectedConvBwdWeightLauncher::launch(args, stream_cfg); -} -#endif - -/** - * @brief Execute convolution based on direction specified in prob - * - * Parameter mapping varies by direction: - * Forward (direction=0): - * input_ptr = X (input tensor) - * weight_ptr = W (weight tensor) - * output_ptr = Y (output buffer) - * - * Backward Data (direction=1): - * input_ptr = dY (grad_output - gradient from next layer) - * weight_ptr = W (weight tensor, frozen) - * output_ptr = dX (grad_input buffer) - * - * Backward Weight (direction=2): - * input_ptr = X (forward input tensor) - * weight_ptr = dY (grad_output - gradient from next layer) - * output_ptr = dW (grad_weight buffer) - */ -float conv_dispatcher_run(const void* input_ptr, - const void* weight_ptr, - void* output_ptr, - const ConvProblemC* prob, - void* stream) -{ - // Validate all required pointers before kernel launch - if(!g_dispatcher || !prob) - return -1.0f; - if(!input_ptr || !weight_ptr || !output_ptr) - return -1.0f; // Null data pointer would cause kernel crash - - // Build problem for kernel selection - ConvProblem problem; - problem.N = prob->N; - problem.G = prob->G; - problem.C = prob->C; - problem.K = prob->K; - problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w}; - problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x}; - problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w}; - problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w}; - problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w}; - problem.op = static_cast(prob->direction); - problem.compute_output_size(); - - // Select kernel - const auto* kernel = g_dispatcher->select(problem); - if(!kernel) - return -1.0f; - - // Dispatch based on direction - switch(prob->direction) - { - case 0: // Forward (always available) - return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream); - -#ifdef CONV_BWD_DATA_AVAILABLE - case 1: // Backward data - // Convention: caller passes (grad_output, weight, grad_input_buffer) - // in the (input_ptr, weight_ptr, output_ptr) slots respectively. - // run_bwd_data expects: (grad_output, weight, grad_input) - return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream); -#endif - -#ifdef CONV_BWD_WEIGHT_AVAILABLE - case 2: // Backward weight - // Convention: caller passes (input, grad_output, grad_weight_buffer) - // in the (input_ptr, weight_ptr, output_ptr) slots respectively. - // run_bwd_weight expects: (input, grad_output, grad_weight) - return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream); -#endif - - default: return -1.0f; - } -} - -// ============================================================================= -// Info -// ============================================================================= - -const char* conv_dispatcher_version() { return "1.0.0"; } +// ========================================================================= +// Library info +// ========================================================================= +const char* conv_dispatcher_version() { return "2.0.0"; } int conv_dispatcher_has_kernels() { - return 1; // Forward kernel is required +#if defined(CONV_FWD_2D_AVAILABLE) || defined(CONV_FWD_3D_AVAILABLE) + return 1; +#else + return 0; +#endif } int conv_dispatcher_has_bwd_data() { -#ifdef CONV_BWD_DATA_AVAILABLE +#if defined(CONV_BWD_DATA_2D_AVAILABLE) || defined(CONV_BWD_DATA_3D_AVAILABLE) return 1; #else return 0; @@ -401,11 +85,240 @@ int conv_dispatcher_has_bwd_data() int conv_dispatcher_has_bwd_weight() { -#ifdef CONV_BWD_WEIGHT_AVAILABLE +#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE) || defined(CONV_BWD_WEIGHT_3D_AVAILABLE) return 1; #else return 0; #endif } +int conv_dispatcher_get_kernel_count() +{ + return CONV_KERNEL_COUNT; // defined in conv_python_dispatch.hpp +} + +int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size) +{ + if(!buffer || buffer_size <= 0 || index < 0 || index >= CONV_KERNEL_COUNT) + return -1; + std::strncpy(buffer, CONV_KERNEL_NAMES[index], buffer_size - 1); + buffer[buffer_size - 1] = '\0'; + return 0; +} + +// ========================================================================= +// Support query +// ========================================================================= +bool conv_dispatcher_is_supported(const ConvProblemC* prob) +{ + if(!prob) + return false; + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + switch(prob->direction) + { + case CONV_FORWARD: +#if defined(CONV_FWD_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_FWD_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + case CONV_BWD_DATA: +#if defined(CONV_BWD_DATA_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_BWD_DATA_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + case CONV_BWD_WEIGHT: +#if defined(CONV_BWD_WEIGHT_3D_AVAILABLE) + if(is_3d) + return true; +#endif +#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE) + if(!is_3d) + return true; +#endif + return false; + default: return false; + } +} + +// ========================================================================= +// ConvParam builders +// ========================================================================= +static ck_tile::conv::ConvParam make_param_2d(const ConvProblemC* p) +{ + return ck_tile::conv::ConvParam{2, + p->G, + p->N, + p->K, + p->C, + {p->filter_y, p->filter_x}, + {p->input_h, p->input_w}, + {p->stride_h, p->stride_w}, + {p->dilation_h, p->dilation_w}, + {p->pad_h, p->pad_w}, + {p->pad_h, p->pad_w}}; +} + +static ck_tile::conv::ConvParam make_param_3d(const ConvProblemC* p) +{ + return ck_tile::conv::ConvParam{3, + p->G, + p->N, + p->K, + p->C, + {p->filter_z, p->filter_y, p->filter_x}, + {p->input_d, p->input_h, p->input_w}, + {p->stride_d, p->stride_h, p->stride_w}, + {p->dilation_d, p->dilation_h, p->dilation_w}, + {p->pad_d, p->pad_h, p->pad_w}, + {p->pad_d, p->pad_h, p->pad_w}}; +} + +// ========================================================================= +// Kernel launch helpers +// ========================================================================= + +#ifdef CONV_FWD_2D_AVAILABLE +static float +launch_fwd_2d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return SelectedConvKernelLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_FWD_3D_AVAILABLE +static float +launch_fwd_3d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvFwd3dLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_DATA_2D_AVAILABLE +static float launch_bwd_data_2d( + const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return SelectedConvBwdDataLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_DATA_3D_AVAILABLE +static float launch_bwd_data_3d( + const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvBwdData3dLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE +static float launch_bwd_weight_2d( + const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_2d(p); + const int k_batch = (p->split_k > 1) ? p->split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return SelectedConvBwdWeightLauncher::launch(args, sc); +} +#endif + +#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE +static float launch_bwd_weight_3d( + const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream) +{ + auto param = make_param_3d(p); + const int k_batch = (p->split_k > 1) ? p->split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch); + ck_tile::stream_config sc{stream, true, 1, 3, 10}; + return ConvBwdWeight3dLauncher::launch(args, sc); +} +#endif + +// ========================================================================= +// Main dispatch +// +// direction=0 (forward): a=X(input), b=W(weight), c=Y(output) +// direction=1 (bwd_data): a=dY(grad_out), b=W(weight), c=dX(grad_in) +// direction=2 (bwd_weight): a=X(input), b=dY(grad_out), c=dW(grad_wei) +// ========================================================================= +float conv_dispatcher_run( + const void* a_ptr, const void* b_ptr, void* c_ptr, const ConvProblemC* prob, void* stream) +{ + if(!prob || !a_ptr || !b_ptr || !c_ptr) + return -1.0f; + + const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1); + hipStream_t hip_stream = static_cast(stream); + + try + { + switch(prob->direction) + { + case CONV_FORWARD: +#ifdef CONV_FWD_3D_AVAILABLE + if(is_3d) + return launch_fwd_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_FWD_2D_AVAILABLE + if(!is_3d) + return launch_fwd_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + case CONV_BWD_DATA: +#ifdef CONV_BWD_DATA_3D_AVAILABLE + if(is_3d) + return launch_bwd_data_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_BWD_DATA_2D_AVAILABLE + if(!is_3d) + return launch_bwd_data_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + case CONV_BWD_WEIGHT: +#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE + if(is_3d) + return launch_bwd_weight_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif +#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE + if(!is_3d) + return launch_bwd_weight_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream); +#endif + return -2.0f; + + default: return -1.0f; + } + } + catch(const std::exception&) + { + return -3.0f; // Kernel rejected args (e.g. unsupported tile/channel combo) + } + catch(...) + { + return -3.0f; + } +} + } // extern "C" diff --git a/dispatcher/codegen/ADDING_NEW_GPU.md b/dispatcher/codegen/ADDING_NEW_GPU.md index 0bd2966a85..664b59b6b1 100644 --- a/dispatcher/codegen/ADDING_NEW_GPU.md +++ b/dispatcher/codegen/ADDING_NEW_GPU.md @@ -9,8 +9,8 @@ Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatche The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications: ``` -arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python) - → arch_specs_generated.hpp (C++) +arch_specs.json -> generate_arch_specs.py -> arch_specs_generated.py (Python) + -> arch_specs_generated.hpp (C++) ``` ## Quick Start @@ -175,14 +175,14 @@ for error in result.errors: ``` codegen/ -├── arch_specs.json # Single source of truth (EDIT THIS) -├── generate_arch_specs.py # Generator script -├── arch_specs_generated.py # Generated Python module -└── ADDING_NEW_GPU.md # This file +|---- arch_specs.json # Single source of truth (EDIT THIS) +|---- generate_arch_specs.py # Generator script +|---- arch_specs_generated.py # Generated Python module ++---- ADDING_NEW_GPU.md # This file include/ck_tile/dispatcher/ -├── arch_specs_generated.hpp # Generated C++ header -└── arch_filter.hpp # C++ filter +|---- arch_specs_generated.hpp # Generated C++ header ++---- arch_filter.hpp # C++ filter ``` ## Best Practices diff --git a/dispatcher/codegen/README.md b/dispatcher/codegen/README.md index 2d753924f5..40a9b7b8c1 100644 --- a/dispatcher/codegen/README.md +++ b/dispatcher/codegen/README.md @@ -1,11 +1,22 @@ -# CK Tile GEMM Unified Code Generator +# CK Tile Unified Code Generators -Single source of truth for all GEMM kernel generation. +Single source of truth for GEMM and Grouped Convolution kernel generation. > **See also:** [Main Dispatcher README](../README.md) for installation and core concepts. +## Shared Infrastructure + +Both GEMM and Grouped Conv generators share common code via `codegen_common.py`: +- `TileConfig` - Dataclass for tile dimensions +- `TraitConfigBase` - Base for kernel trait configurations with arch-aware validation +- `CommonTypeMappings` - Dtype-to-C++ type mappings +- `parallel_generate()` - Parallel kernel generation with per-kernel progress logging +- Arch-aware expansion helpers (`valid_wave_configs`, `valid_warp_configs`, etc.) + ## Quick Start +### GEMM + ```bash cd dispatcher/codegen @@ -22,6 +33,25 @@ python3 unified_gemm_codegen.py \ --variants standard preshuffle multi_d ``` +### Grouped Convolution + +```bash +cd dispatcher/codegen + +# Generate forward FP16 grouped conv kernels +python3 unified_grouped_conv_codegen.py \ + --output-dir ../build/generated_kernels \ + --datatype fp16 \ + --variant forward \ + --ndim-spatial 2 + +# Generate backward data kernels +python3 unified_grouped_conv_codegen.py \ + --output-dir ../build/generated_kernels \ + --variant backward_data \ + --ndim-spatial 2 +``` + ## Using from Python ```python @@ -58,13 +88,13 @@ results = codegen.generate_all() ## Variants ### Standard -Basic GEMM: `C = A × B` +Basic GEMM: `C = A x B` ### PreShuffle Optimized weight access with LDS pre-shuffling. Best for large matrices. ### Multi-D -Element-wise fusion: `C = op(A × B + D0 + D1 + ...)` +Element-wise fusion: `C = op(A x B + D0 + D1 + ...)` Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` @@ -72,10 +102,11 @@ Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh` ``` generated_kernels/ -├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp -├── gemm_fp16_rcr_compv4_..._preshuffle.hpp -├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp -└── ... +|---- gemm_fp16_rcr_compv4_..._128x128x32_....hpp # GEMM kernels +|---- gemm_fp16_rcr_compv4_..._preshuffle.hpp +|---- gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp +|---- grouped_conv_fwd_fp16_nhwgc_..._128x128x32_....hpp # Grouped conv kernels ++---- ... ``` ## Configuration Files diff --git a/dispatcher/codegen/codegen_common.py b/dispatcher/codegen/codegen_common.py new file mode 100644 index 0000000000..4e9e8de1b3 --- /dev/null +++ b/dispatcher/codegen/codegen_common.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Shared codegen infrastructure for GEMM and grouped convolution code generators. + +Extracted from unified_gemm_codegen.py + arch-aware expansion helpers from conv. +Both unified_gemm_codegen.py and unified_grouped_conv_codegen.py import from here +to eliminate duplication. +""" + +import logging +import concurrent.futures +from dataclasses import dataclass +from typing import ( + Callable, + ClassVar, + Dict, + FrozenSet, + List, + Optional, + Sequence, + Tuple, + TypeVar, +) + +log = logging.getLogger(__name__) + +T = TypeVar("T") +R = TypeVar("R") + +ANY_INT = -1 + + +# ============================================================================ +# Tile and Trait Configuration (shared between GEMM and Conv) +# ============================================================================ + + +@dataclass +class TileConfig: + """Tile configuration parameters shared by GEMM and grouped conv.""" + + tile_m: int + tile_n: int + tile_k: int + warp_m: int + warp_n: int + warp_k: int + warp_tile_m: int + warp_tile_n: int + warp_tile_k: int + + def is_valid(self) -> bool: + if self.tile_m <= 0 or self.tile_n <= 0 or self.tile_k <= 0: + return False + return ( + self.tile_m % (self.warp_m * self.warp_tile_m) == 0 + and self.tile_n % (self.warp_n * self.warp_tile_n) == 0 + and self.tile_k % (self.warp_k * self.warp_tile_k) == 0 + ) + + +@dataclass +class TraitConfigBase: + """ + Base kernel trait configuration shared by GEMM and grouped conv. + + GEMM extends this with ``persistent``; grouped conv extends with + ``double_smem_buffer`` and ``num_groups_to_merge``. + """ + + pipeline: str # mem, compv3, compv4, compv5, ... + epilogue: str # cshuffle, default + scheduler: str # intrawave, interwave + pad_m: bool + pad_n: bool + pad_k: bool + + # Unsupported (pipeline, epilogue, scheduler) combinations. + # Only 'mem' and 'basic_v1' pipelines support interwave; all compute + # pipelines (compv3/v4/v5/v6/async) only support intrawave. + _UNSUPPORTED: ClassVar[FrozenSet] = frozenset( + { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + ("compv5", "cshuffle", "interwave"), + ("compv5", "default", "interwave"), + ("compv6", "cshuffle", "interwave"), + ("compv6", "default", "interwave"), + ("comp_async", "cshuffle", "interwave"), + ("comp_async", "default", "interwave"), + ("basic_async_v1", "cshuffle", "interwave"), + ("basic_async_v1", "default", "interwave"), + } + ) + + def is_valid(self) -> bool: + return (self.pipeline, self.epilogue, self.scheduler) not in self._UNSUPPORTED + + +# ============================================================================ +# Type Mappings (centralized for both GEMM and conv codegen) +# ============================================================================ + + +class CommonTypeMappings: + """Centralized type mappings shared by GEMM and grouped conv codegen.""" + + DTYPE_TO_CK = { + "fp16": "fp16_t", + "bf16": "bf16_t", + "fp32": "float", + "fp8": "fp8_t", + "bf8": "bf8_t", + "int8": "int8_t", + } + + DTYPE_TO_CK_QUALIFIED = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", + "fp32": "float", + "fp8": "ck_tile::fp8_t", + "bf8": "ck_tile::bf8_t", + "int8": "int8_t", + } + + DTYPE_TO_DISPATCHER = { + "fp16": "DataType::FP16", + "bf16": "DataType::BF16", + "fp32": "DataType::FP32", + "fp8": "DataType::FP8", + "bf8": "DataType::BF8", + "int8": "DataType::INT8", + } + + # GEMM-specific layout mappings ("r"/"c" for row/column major). + # Convolution layouts (NHWGC, GKYXC, etc.) are handled by + # unified_grouped_conv_codegen.py via GroupedConvLayout / GroupedConvTypeMappings. + GEMM_LAYOUT_TO_CK = { + "r": "tensor_layout::gemm::RowMajor", + "c": "tensor_layout::gemm::ColumnMajor", + } + LAYOUT_TO_CK = GEMM_LAYOUT_TO_CK # backward compat alias + + GEMM_LAYOUT_TO_DISPATCHER = { + "r": "LayoutTag::RowMajor", + "c": "LayoutTag::ColMajor", + } + LAYOUT_TO_DISPATCHER = GEMM_LAYOUT_TO_DISPATCHER # backward compat alias + + # GEMM-only pipeline mappings (used by unified_gemm_codegen.py). + # Convolution pipelines are in GroupedConvTypeMappings + # (unified_grouped_conv_codegen.py). CK Tile conv supports: + # BASIC_V1, Mem, CompV3, CompV4, CompV5, CompV6, ASYNC_V1, ASYNC_V4. + # The dispatcher currently generates: mem, compv3, compv4. + # preshufflev2 is GEMM-only (weight pre-shuffle for GEMM, not conv). + PIPELINE_TO_CK = { + "mem": "GemmPipelineAgBgCrMem", + "compv3": "GemmPipelineAgBgCrCompV3", + "compv4": "GemmPipelineAgBgCrCompV4", + "compv5": "GemmPipelineAgBgCrCompV5", + "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_BASE = { + "mem": "BaseGemmPipelineAgBgCrMem", + "compv3": "BaseGemmPipelineAgBgCrCompV3", + "compv4": "BaseGemmPipelineAgBgCrCompV4", + "compv5": "BaseGemmPipelineAgBgCrCompV5", + "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2", + } + + PIPELINE_TO_DISPATCHER = { + "mem": "Pipeline::Mem", + "compv3": "Pipeline::CompV3", + "compv4": "Pipeline::CompV4", + "compv5": "Pipeline::CompV5", + "preshufflev2": "Pipeline::PreShuffleV2", + } + + SCHEDULER_TO_CK = { + "intrawave": "GemmPipelineScheduler::Intrawave", + "interwave": "GemmPipelineScheduler::Interwave", + "default": "GemmPipelineScheduler::Default", + } + + SCHEDULER_TO_DISPATCHER = { + "intrawave": "Scheduler::Intrawave", + "interwave": "Scheduler::Interwave", + "default": "Scheduler::Auto", + } + + EPILOGUE_TO_DISPATCHER = { + "cshuffle": "Epilogue::CShuffle", + "default": "Epilogue::Default", + } + + @staticmethod + def get_output_dtype(dtype: str) -> str: + """Get output datatype (fp8/bf8 -> fp16).""" + return "fp16" if dtype in ("fp8", "bf8") else dtype + + +# ============================================================================ +# Code Generation Helpers +# ============================================================================ + + +def generate_cpp_compilation_unit(kernel_name: str) -> str: + """Generate a .cpp compilation unit that includes a kernel header. + + This is the standard pattern: one .cpp per kernel that just includes + the generated .hpp header, causing template instantiation. + """ + return ( + f"// Auto-generated compilation unit for {kernel_name}\n" + f'#include "{kernel_name}.hpp"\n' + ) + + +def parallel_generate( + generate_fn: Callable[[T], R], + items: Sequence[T], + parallel: bool = True, +) -> List[R]: + """Run ``generate_fn`` over ``items``, optionally in parallel. + + Logs per-item progress (best-of-conv pattern). + Returns a flat list of results in completion order. + """ + results: List[R] = [] + if not items: + return results + + if parallel and len(items) > 1: + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = {executor.submit(generate_fn, item): item for item in items} + for future in concurrent.futures.as_completed(futures): + result = future.result() + results.append(result) + log.info("Generated: %s", futures[future]) + else: + for item in items: + result = generate_fn(item) + results.append(result) + log.info("Generated: %s", item) + + return results + + +# ============================================================================ +# Arch-Aware Expansion Helpers (adopted from conv kernel_decl.hpp) +# ============================================================================ + +# These load from arch_specs_generated when available, falling back to +# hardcoded defaults that match the most common arch (gfx942). + +_arch_data_cache: Optional[Dict] = None + + +def _get_arch_data() -> Dict: + """Load arch filter data, with caching.""" + global _arch_data_cache + if _arch_data_cache is not None: + return _arch_data_cache + + try: + from arch_specs_generated import ( + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + TRAIT_UNSUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + _arch_data_cache = { + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + _arch_data_cache = { + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + }, + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv4", "cshuffle", "interwave"), + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + return _arch_data_cache + + +def valid_wave_configs(arch: str) -> List[List[int]]: + """Return valid [wave_m, wave_n, wave_k] combos for *arch*.""" + data = _get_arch_data() + return data["warp_combos"].get(arch, [[2, 2, 1]]) + + +def valid_warp_configs(arch: str, dtype: str) -> List[List[int]]: + """Return valid [warp_tile_m, warp_tile_n, warp_tile_k] combos for *arch*/*dtype*. + + The dtype key is constructed as ``{dtype}_{dtype}_{acc}`` where acc is + fp32 for float types and int32 for int8. + """ + data = _get_arch_data() + acc = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc}" + arch_tiles = data["warp_tile_combos"].get(arch, {}) + return arch_tiles.get(dtype_key, [[32, 32, 16]]) + + +def valid_trait_configs() -> List[Tuple[str, str]]: + """Return valid (pipeline, scheduler) pairs. + + Compute pipelines only support intrawave; mem supports both. + """ + return [ + ("compv3", "intrawave"), + ("compv4", "intrawave"), + ("compv5", "intrawave"), + ("mem", "intrawave"), + ("mem", "interwave"), + ] + + +def needs_wave_expansion(config: dict) -> bool: + """True if wave_m or wave_n is a wildcard (ANY_INT = -1).""" + return config.get("wave_m", 2) == ANY_INT or config.get("wave_n", 2) == ANY_INT + + +def needs_warp_expansion(config: dict) -> bool: + """True if warp_m or warp_n is a wildcard (ANY_INT = -1).""" + return config.get("warp_m", 32) == ANY_INT or config.get("warp_n", 32) == ANY_INT + + +def needs_pipeline_expansion(config: dict) -> bool: + """True if pipeline is a wildcard (\"*\").""" + return config.get("pipeline", "compv4") == "*" diff --git a/dispatcher/codegen/generate_dispatcher_registration.py b/dispatcher/codegen/generate_dispatcher_registration.py index 024ec4a7c8..8e8b67376c 100644 --- a/dispatcher/codegen/generate_dispatcher_registration.py +++ b/dispatcher/codegen/generate_dispatcher_registration.py @@ -109,7 +109,7 @@ inline void register_all_kernels() """ output_file.write_text(content) - print(f"✓ Generated registration header: {output_file}") + print(f"OK Generated registration header: {output_file}") def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path): @@ -143,7 +143,7 @@ namespace generated { """ output_file.write_text(content) - print(f"✓ Generated registration implementation: {output_file}") + print(f"OK Generated registration implementation: {output_file}") def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path): @@ -414,8 +414,8 @@ def main(): with open(manifest_output, "w") as f: json.dump(manifest_data, f, indent=2) - print(f"✓ Generated manifest: {manifest_output}") - print("\n✓ Registration code generation complete!") + print(f"OK Generated manifest: {manifest_output}") + print("\nOK Registration code generation complete!") print(f" Total kernels: {len(kernels)}") print(" Output files:") print(f" - {registration_header}") diff --git a/dispatcher/codegen/generate_kernel_wrappers.py b/dispatcher/codegen/generate_kernel_wrappers.py index 53a9bff3ed..e11bd7a0a5 100644 --- a/dispatcher/codegen/generate_kernel_wrappers.py +++ b/dispatcher/codegen/generate_kernel_wrappers.py @@ -17,10 +17,10 @@ Usage: Output structure: build/kernel_wrappers/ - ├── gemm_fp16_rcr_128x128x32.cpp - ├── gemm_fp16_rcr_256x256x64.cpp - ├── conv_fwd_fp16_2d_128x128.cpp - └── ... + |---- gemm_fp16_rcr_128x128x32.cpp + |---- gemm_fp16_rcr_256x256x64.cpp + |---- conv_fwd_fp16_2d_128x128.cpp + +---- ... Each .cpp simply includes its corresponding .hpp and forces symbol emission. """ diff --git a/dispatcher/codegen/kernel_config_loader.py b/dispatcher/codegen/kernel_config_loader.py index 537fc40581..980b4e5fd0 100644 --- a/dispatcher/codegen/kernel_config_loader.py +++ b/dispatcher/codegen/kernel_config_loader.py @@ -359,8 +359,8 @@ class ConvTraitConfig: @dataclass -class ConvKernelConfig: - """Complete convolution kernel configuration""" +class GroupedConvKernelConfig: + """Complete grouped convolution kernel configuration""" tile: ConvTileConfig = field(default_factory=ConvTileConfig) trait: ConvTraitConfig = field(default_factory=ConvTraitConfig) @@ -419,7 +419,11 @@ class ConvKernelConfig: def kernel_name(self) -> str: """Generate kernel name from config""" - variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"} + variant_map = { + "forward": "fwd", + "bwd_data": "bwd_data", + "bwd_weight": "bwd_weight", + } var_str = variant_map.get(self.variant, self.variant) name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d" @@ -433,11 +437,11 @@ class ConvKernelConfig: @dataclass -class ConvKernelConfigSet: +class GroupedConvKernelConfigSet: """A set of convolution kernel configurations loaded from JSON""" name: str = "default" - configs: List[ConvKernelConfig] = field(default_factory=list) + configs: List[GroupedConvKernelConfig] = field(default_factory=list) # Tile parameter ranges tile_m_values: List[int] = field(default_factory=lambda: [128]) @@ -481,7 +485,7 @@ class ConvKernelConfigSet: layout: str = "nhwgc" gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"]) - def generate_configs(self) -> Iterator[ConvKernelConfig]: + def generate_configs(self) -> Iterator[GroupedConvKernelConfig]: """Generate all kernel configurations (cartesian product)""" # Tile parameters tile_params = itertools.product( @@ -548,7 +552,7 @@ class ConvKernelConfigSet: double_smem_buffer=trait[6], num_groups_to_merge=trait[7], ) - yield ConvKernelConfig( + yield GroupedConvKernelConfig( tile=tile_cfg, trait=trait_cfg, dtype_input=self.dtype_input, @@ -599,7 +603,9 @@ class ConvKernelConfigSet: return tile_count * trait_count * extra_count * len(self.gpu_targets) -def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: +def load_grouped_conv_kernel_configs( + json_path: str | Path, +) -> GroupedConvKernelConfigSet: """ Load convolution kernel configurations from a JSON file. @@ -607,14 +613,14 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: json_path: Path to JSON configuration file Returns: - ConvKernelConfigSet with all parameter values loaded + GroupedConvKernelConfigSet with all parameter values loaded """ json_path = Path(json_path) with open(json_path) as f: data = json.load(f) - config_set = ConvKernelConfigSet() + config_set = GroupedConvKernelConfigSet() # Name config_set.name = data.get("kernel_set_name", json_path.stem) @@ -680,15 +686,15 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet: def generate_cpp_conv_kernel_set_declaration( - config_set: ConvKernelConfigSet, + config_set: GroupedConvKernelConfigSet, set_name: Optional[str] = None, ) -> str: """ - Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet. + Generate C++ DECL_GROUPED_CONV_KERNEL_SET code from a GroupedConvKernelConfigSet. """ name = set_name or config_set.name - lines = [f"DECL_CONV_KERNEL_SET({name},"] + lines = [f"DECL_GROUPED_CONV_KERNEL_SET({name},"] for config in config_set.generate_configs(): line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, ' diff --git a/dispatcher/codegen/unified_gemm_codegen.py b/dispatcher/codegen/unified_gemm_codegen.py index b0dd961be7..a818cec83e 100755 --- a/dispatcher/codegen/unified_gemm_codegen.py +++ b/dispatcher/codegen/unified_gemm_codegen.py @@ -7,7 +7,7 @@ Unified GEMM Code Generator - Single Source of Truth This is THE unified code generator for all GEMM kernel variants: -- Standard GEMM (C = A × B) +- Standard GEMM (C = A x B) - Preshuffle GEMM (optimized weight access) - Multi-D GEMM (element-wise fusion) @@ -25,6 +25,12 @@ from dataclasses import dataclass, asdict from enum import Enum import concurrent.futures +from codegen_common import ( + TileConfig, + TraitConfigBase, + CommonTypeMappings as TypeMappings, +) + # Import architecture filter for GPU-specific validation try: from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType @@ -194,62 +200,14 @@ class GemmVariant(Enum): MULTI_D = "multi_d" -@dataclass -class TileConfig: - """Tile configuration parameters""" - - tile_m: int - tile_n: int - tile_k: int - warp_m: int - warp_n: int - warp_k: int - warp_tile_m: int - warp_tile_n: int - warp_tile_k: int - - def is_valid(self) -> bool: - """Validate tile configuration""" - return ( - self.tile_m % (self.warp_m * self.warp_tile_m) == 0 - and self.tile_n % (self.warp_n * self.warp_tile_n) == 0 - and self.tile_k % (self.warp_k * self.warp_tile_k) == 0 - and self.tile_m > 0 - and self.tile_n > 0 - and self.tile_k > 0 - ) +# TileConfig imported from codegen_common @dataclass -class TraitConfig: - """Kernel trait configuration""" +class TraitConfig(TraitConfigBase): + """GEMM-specific trait configuration extending TraitConfigBase with persistent mode.""" - pipeline: str # mem, compv3, compv4 - epilogue: str # default, cshuffle - scheduler: str # intrawave, interwave - pad_m: bool - pad_n: bool - pad_k: bool - persistent: bool - - def is_valid(self) -> bool: - """Check if trait combination is valid""" - # Unsupported combinations - # Only 'mem' pipeline supports interwave scheduler. - # All compute pipelines (compv3/v4/v5/v6/async) only support intrawave. - unsupported = { - ("compv3", "cshuffle", "interwave"), - ("compv3", "default", "interwave"), - ("compv4", "cshuffle", "interwave"), - ("compv4", "default", "interwave"), - ("compv5", "cshuffle", "interwave"), - ("compv5", "default", "interwave"), - ("compv6", "cshuffle", "interwave"), - ("compv6", "default", "interwave"), - ("comp_async", "cshuffle", "interwave"), - ("comp_async", "default", "interwave"), - } - return (self.pipeline, self.epilogue, self.scheduler) not in unsupported + persistent: bool = False @dataclass @@ -345,89 +303,7 @@ class KernelConfig: # ============================================================================ -class TypeMappings: - """Centralized type mappings for code generation""" - - DTYPE_TO_CK = { - "fp16": "fp16_t", - "bf16": "bf16_t", - "fp32": "float", - "fp8": "fp8_t", - "bf8": "bf8_t", - "int8": "int8_t", - } - - # Fully-qualified types for use outside of 'using namespace ck_tile' scope - DTYPE_TO_CK_QUALIFIED = { - "fp16": "ck_tile::fp16_t", - "bf16": "ck_tile::bf16_t", - "fp32": "float", # Built-in type, no namespace - "fp8": "ck_tile::fp8_t", - "bf8": "ck_tile::bf8_t", - "int8": "int8_t", # Built-in type - } - - DTYPE_TO_DISPATCHER = { - "fp16": "DataType::FP16", - "bf16": "DataType::BF16", - "fp32": "DataType::FP32", - "fp8": "DataType::FP8", - "bf8": "DataType::BF8", - "int8": "DataType::INT8", - } - - LAYOUT_TO_CK = { - "r": "tensor_layout::gemm::RowMajor", - "c": "tensor_layout::gemm::ColumnMajor", - } - - LAYOUT_TO_DISPATCHER = { - "r": "LayoutTag::RowMajor", - "c": "LayoutTag::ColMajor", - } - - PIPELINE_TO_CK = { - "mem": "GemmPipelineAgBgCrMem", - "compv3": "GemmPipelineAgBgCrCompV3", - "compv4": "GemmPipelineAgBgCrCompV4", - "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2", - } - - PIPELINE_TO_BASE = { - "mem": "BaseGemmPipelineAgBgCrMem", - "compv3": "BaseGemmPipelineAgBgCrCompV3", - "compv4": "BaseGemmPipelineAgBgCrCompV4", - "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2", - } - - PIPELINE_TO_DISPATCHER = { - "mem": "Pipeline::Mem", - "compv3": "Pipeline::CompV3", - "compv4": "Pipeline::CompV4", - "preshufflev2": "Pipeline::PreShuffleV2", - } - - SCHEDULER_TO_CK = { - "intrawave": "GemmPipelineScheduler::Intrawave", - "interwave": "GemmPipelineScheduler::Interwave", - "default": "GemmPipelineScheduler::Default", - } - - SCHEDULER_TO_DISPATCHER = { - "intrawave": "Scheduler::Intrawave", - "interwave": "Scheduler::Interwave", - "default": "Scheduler::Auto", - } - - EPILOGUE_TO_DISPATCHER = { - "cshuffle": "Epilogue::CShuffle", - "default": "Epilogue::Default", - } - - @staticmethod - def get_output_dtype(dtype: str) -> str: - """Get output datatype (fp8/bf8 -> fp16)""" - return "fp16" if dtype in ["fp8", "bf8"] else dtype +# TypeMappings imported from codegen_common as CommonTypeMappings -> TypeMappings alias # ============================================================================ @@ -1068,7 +944,11 @@ class UnifiedGemmCodegen: } def generate_all(self, parallel: bool = True) -> Dict: - """Generate all kernels""" + """Generate all kernels. + + When parallel=True, all configs across all variants are collected first, + then generated concurrently in a single thread pool for maximum throughput. + """ log.info("Generating GEMM kernels:") log.info(f" Datatype: {self.datatype}") log.info(f" Layout: {self.layout}") @@ -1078,49 +958,24 @@ class UnifiedGemmCodegen: results = {"kernels": [], "wrappers": [], "failed": []} - # Get configurations + # Collect ALL configs across all variants/preselected sets upfront + all_configs = [] if self.use_preselected: - configs = self._get_preselected_configs() - log.info(f" Total configurations: {len(configs)}") + all_configs = self._get_preselected_configs() + log.info(f" Total configurations: {len(all_configs)}") else: for variant in self.variants: - log.info(f"\nGenerating {variant.value} kernels...") configs = self._get_configs_for_variant(variant) - log.info(f" Configurations: {len(configs)}") + log.info(f" {variant.value}: {len(configs)} configurations") + all_configs.extend(configs) + log.info(f" Total across all variants: {len(all_configs)}") - if parallel: - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [ - executor.submit(self._generate_one, cfg) for cfg in configs - ] - for future in concurrent.futures.as_completed(futures): - try: - k, w = future.result() - results["kernels"].append(k) - results["wrappers"].append(w) - except Exception as e: - results["failed"].append(str(e)) - log.error(f"Failed: {e}") - else: - for cfg in configs: - try: - k, w = self._generate_one(cfg) - results["kernels"].append(k) - results["wrappers"].append(w) - except Exception as e: - results["failed"].append(str(e)) - log.error(f"Failed: {e}") - - # Generate registration header - if results["wrappers"]: - self._generate_registration_header(results["wrappers"]) - - return results - - # Generate from preselected set - if parallel: + # Generate all configs in a single parallel pass + if parallel and all_configs: with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [executor.submit(self._generate_one, cfg) for cfg in configs] + futures = [ + executor.submit(self._generate_one, cfg) for cfg in all_configs + ] for future in concurrent.futures.as_completed(futures): try: k, w = future.result() @@ -1130,7 +985,7 @@ class UnifiedGemmCodegen: results["failed"].append(str(e)) log.error(f"Failed: {e}") else: - for cfg in configs: + for cfg in all_configs: try: k, w = self._generate_one(cfg) results["kernels"].append(k) @@ -1139,7 +994,6 @@ class UnifiedGemmCodegen: results["failed"].append(str(e)) log.error(f"Failed: {e}") - # Generate registration header if results["wrappers"]: self._generate_registration_header(results["wrappers"]) @@ -1638,12 +1492,19 @@ def main(): # Write to temp file and use as config import tempfile + import os as _os - with tempfile.NamedTemporaryFile( + _tmp_config = tempfile.NamedTemporaryFile( mode="w", suffix=".json", delete=False - ) as f: - json.dump(full_config, f) - args.config = Path(f.name) + ) + try: + json.dump(full_config, _tmp_config) + _tmp_config.close() + args.config = Path(_tmp_config.name) + except Exception: + _tmp_config.close() + _os.unlink(_tmp_config.name) + raise except json.JSONDecodeError as e: logging.error(f"Invalid tile-config-json: {e}") return 1 @@ -1672,7 +1533,7 @@ def main(): results = codegen.generate_all(parallel=not args.no_parallel) - logging.info("\n✅ Generation complete!") + logging.info("\nGeneration complete.") logging.info(f" Kernels: {len(results['kernels'])}") logging.info(f" Wrappers: {len(results['wrappers'])}") logging.info(f" Failed: {len(results['failed'])}") @@ -1684,7 +1545,7 @@ def main(): # Generate dispatcher registration if requested if args.register: - logging.info("\n📝 Generating dispatcher registration code...") + logging.info("\nGenerating dispatcher registration code...") try: from generate_dispatcher_registration import ( scan_generated_headers, @@ -1701,11 +1562,20 @@ def main(): ) generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp") - logging.info(f"✓ Generated registration code for {len(kernels)} kernels") + logging.info(f"Generated registration code for {len(kernels)} kernels") except Exception as e: logging.error(f"Failed to generate registration code: {e}") return 1 + # Clean up temp config file if we created one + if args.tile_config_json and args.config and args.config.exists(): + try: + import os as _os + + _os.unlink(args.config) + except OSError: + pass + return 0 if not results["failed"] else 1 diff --git a/dispatcher/codegen/unified_grouped_conv_codegen.py b/dispatcher/codegen/unified_grouped_conv_codegen.py new file mode 100644 index 0000000000..ff40cb4ed4 --- /dev/null +++ b/dispatcher/codegen/unified_grouped_conv_codegen.py @@ -0,0 +1,1757 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Unified Grouped Convolution Code Generator + +This is the unified code generator for all grouped convolution kernel variants: +- Forward grouped convolution +- Backward data grouped convolution +- Backward weight grouped convolution + +Generates both CK Tile kernels AND dispatcher wrappers. +Based on the GEMM codegen pattern. +""" + +import argparse +import logging +from pathlib import Path +from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from enum import Enum + +from codegen_common import ( + TileConfig, + TraitConfigBase, + parallel_generate, +) + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +log = logging.getLogger(__name__) + +# Import architecture filter for GPU-specific validation +try: + from arch_filter import ArchFilter, OperatorType + + HAS_ARCH_FILTER = True +except ImportError: + HAS_ARCH_FILTER = False + ArchFilter = None + OperatorType = None + + +# ============================================================================ +# Configuration and Data Structures +# ============================================================================ + + +class GroupedConvVariant(Enum): + """Grouped convolution kernel variants""" + + FORWARD = "forward" + BACKWARD_DATA = "bwd_data" + BACKWARD_WEIGHT = "bwd_weight" + + +class GroupedConvLayout(Enum): + """Grouped convolution data layouts""" + + # 1D + NWGC = "NWGC" # Input/Output: N W G C + GKXC = "GKXC" # Weight: G K X C + NWGK = "NWGK" # Output: N W G K + + # 2D + NHWGC = "NHWGC" # Input: N H W G C + GKYXC = "GKYXC" # Weight: G K Y X C + NHWGK = "NHWGK" # Output: N H W G K + + # 3D + NDHWGC = "NDHWGC" # Input: N D H W G C + GKZYXC = "GKZYXC" # Weight: G K Z Y X C + NDHWGK = "NDHWGK" # Output: N D H W G K + + +@dataclass +class GroupedConvTraitConfig(TraitConfigBase): + """Kernel trait configuration for grouped convolution (extends TraitConfigBase). + + Conv-specific extensions beyond TraitConfigBase. These map to + GroupedConvTraits template parameters in grouped_convolution_utils.hpp: + - double_smem_buffer: ping-pong LDS for compute V4+ pipelines + - num_groups_to_merge: fuse multiple groups into one tile (NumGroupsToMerge) + - split_image: split spatial dims for large tensors (EnableSplitImage) + - explicit_gemm: use explicit GEMM path (ExplicitGemm) + - two_stage: two-stage bwd_weight with fp32 workspace + elementwise convert + + Note: CK Tile already uses long_index_t (64-bit) for group strides and + batch offsets, so there is no separate "large_tensor" flag. For large + spatial dimensions, use split_image=True instead. + """ + + double_smem_buffer: bool = False + num_groups_to_merge: int = 1 + split_image: bool = False + explicit_gemm: bool = False + two_stage: bool = False + + +# Backward compatibility alias +TraitConfig = GroupedConvTraitConfig + + +@dataclass +class GroupedConvKernelConfig: + """Complete grouped convolution kernel configuration""" + + tile: TileConfig + trait: GroupedConvTraitConfig + variant: GroupedConvVariant = GroupedConvVariant.FORWARD + ndim_spatial: int = 2 # 1D, 2D, or 3D + arch: str = "gfx942" # Target architecture + layout: Union[str, GroupedConvLayout] = ( + "nhwgc" # Data layout (e.g., "nhwgc", "ndhwgc") + ) + + # Vector sizes: a=4 for fp16 input (8-byte aligned global loads), + # b=8 for weight tensor, c=8 for output stores. These match the + # CK Tile default vectorization widths for fp16 on CDNA3 (gfx942). + vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + vector_sizes: Optional[Tuple[int, int, int]] = None + + # Occupancy parameters + block_per_cu: int = 1 + num_wave_groups: int = 1 + num_groups_to_merge: int = 1 + + # Double buffering + double_smem_buffer: bool = False + + def __post_init__(self): + if self.vector_sizes is not None: + self.vector_size_a, self.vector_size_b, self.vector_size_c = ( + self.vector_sizes[:3] + ) + # Sync trait fields with top-level fields (trait is source of truth + # when both are specified, but top-level overrides default trait values). + if self.double_smem_buffer and not self.trait.double_smem_buffer: + self.trait.double_smem_buffer = self.double_smem_buffer + elif self.trait.double_smem_buffer: + self.double_smem_buffer = self.trait.double_smem_buffer + if self.num_groups_to_merge != 1 and self.trait.num_groups_to_merge == 1: + self.trait.num_groups_to_merge = self.num_groups_to_merge + elif self.trait.num_groups_to_merge != 1: + self.num_groups_to_merge = self.trait.num_groups_to_merge + + def _layout_str(self) -> str: + """Get layout as lowercase string for naming.""" + if hasattr(self.layout, "value"): + return self.layout.value.lower() + return str(self.layout).lower() + + def name(self, datatype: str) -> str: + """ + Generate kernel name that uniquely identifies the kernel configuration. + + Format: grouped_conv_{variant}_{dtype}_{layout}_{ndim}d_{pipeline}_{epilogue}_{scheduler} + _{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k} + _{warp_tile_m}x{warp_tile_n}x{warp_tile_k} + [_vec{a}_{b}_{c}][_bpc{n}][_wg{n}][_gm{n}][_dsb][_pad{mnk}] + + All parameters that affect kernel behavior MUST be included to ensure + unique names for unique configurations: + - Variant (fwd/bwd_data/bwd_weight) + - Data type + - Layout (nhwgc, nchw, ndhwgc, etc.) + - Spatial dimensions (2d/3d) + - Pipeline, epilogue, scheduler + - Tile, warp, warp_tile dimensions + - Vector sizes, occupancy hints (if non-default) + - Double SMEM buffer, padding flags + """ + t = self.tile + tr = self.trait + layout_str = self._layout_str() + + variant_str = { + GroupedConvVariant.FORWARD: "fwd", + GroupedConvVariant.BACKWARD_DATA: "bwd_data", + GroupedConvVariant.BACKWARD_WEIGHT: "bwd_weight", + }[self.variant] + + # Core identity: variant, dtype, layout, dims + name = ( + f"grouped_conv_{variant_str}_{datatype}_{layout_str}_{self.ndim_spatial}d" + ) + + # Pipeline configuration + name += f"_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}" + + # Block tile dimensions (M_Tile x N_Tile x K_Tile) + name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}" + + # Wave distribution (M_Warp x N_Warp x K_Warp) + name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}" + + # Warp tile dimensions (M_Warp_Tile x N_Warp_Tile x K_Warp_Tile) + name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}" + + # Vector sizes (only if non-default) + if (self.vector_size_a, self.vector_size_b, self.vector_size_c) != (4, 8, 8): + name += ( + f"_vec{self.vector_size_a}_{self.vector_size_b}_{self.vector_size_c}" + ) + + # Occupancy hints (only if non-default) + if self.block_per_cu != 1: + name += f"_bpc{self.block_per_cu}" + + if self.num_wave_groups != 1: + name += f"_wg{self.num_wave_groups}" + + if self.num_groups_to_merge != 1: + name += f"_gm{self.num_groups_to_merge}" + + # Double SMEM buffer (for compute V4+) + if self.double_smem_buffer or tr.double_smem_buffer: + name += "_dsb" + + # Two-stage bwd_weight (fp32 workspace + elementwise convert) + if tr.two_stage: + name += "_2stage" + + # Padding suffix (only if not all enabled) + if not (tr.pad_m and tr.pad_n and tr.pad_k): + name += f"_pad{int(tr.pad_m)}{int(tr.pad_n)}{int(tr.pad_k)}" + + return name + + def is_valid_for_arch(self, arch: Optional[str] = None) -> bool: + """Check if configuration is valid for target architecture""" + target_arch = arch if arch is not None else self.arch + + # Check trait validity + if not self.trait.is_valid(): + return False + + # Backward operations have stricter pipeline requirements: + # - Backward weight: compv4/compv5 have transpose_tile2d issues + # - Backward data: compv4 has get_length issues in bwd_data kernel + # Both backward operations ONLY support compv3 and mem pipelines + if self.variant in ( + GroupedConvVariant.BACKWARD_WEIGHT, + GroupedConvVariant.BACKWARD_DATA, + ): + if self.trait.pipeline not in ("compv3", "mem"): + return False + + # Check warp configuration (from arch_specs) + try: + from arch_specs_generated import WARP_SUPPORTED_COMBINATIONS + + supported = WARP_SUPPORTED_COMBINATIONS.get(target_arch) + if supported is None: + return False # Unknown architecture + warp_cfg = [self.tile.warp_m, self.tile.warp_n, self.tile.warp_k] + if warp_cfg not in supported: + return False + except ImportError: + pass # Allow if arch_specs not available + + return True + + +# ============================================================================ +# Type Mappings +# ============================================================================ + + +class GroupedConvTypeMappings: + """Centralized type mappings for grouped convolution code generation""" + + DTYPE_TO_CK = { + "fp16": "half_t", + "bf16": "bf16_t", + "fp32": "float", + } + + # CK Tile conv pipelines (from conv_configs.hpp PipelineTypeTraits). + # basic_v1/mem/compv3 use GroupedConvUniversalPipelineAgBgCrPolicy; + # compv4/compv5/compv6/comp_async/basic_async_v1 use their own default policy. + PIPELINE_TO_CK = { + "basic_v1": "GemmPipeline::BASIC_V1", + "mem": "GemmPipeline::MEMORY", + "compv3": "GemmPipeline::COMPUTE_V3", + "compv4": "GemmPipeline::COMPUTE_V4", + "compv5": "GemmPipeline::COMPUTE_V5", + "compv6": "GemmPipeline::COMPUTE_V6", + "comp_async": "GemmPipeline::COMPUTE_ASYNC", + "basic_async_v1": "GemmPipeline::BASIC_ASYNC_V1", + } + + SCHEDULER_TO_CK = { + "intrawave": "GemmPipelineScheduler::Intrawave", + "interwave": "GemmPipelineScheduler::Interwave", + } + + LAYOUT_1D = { + "in": "tensor_layout::convolution::NWGC", + "wei": "tensor_layout::convolution::GKXC", + "out": "tensor_layout::convolution::NWGK", + } + + LAYOUT_2D = { + "in": "tensor_layout::convolution::NHWGC", + "wei": "tensor_layout::convolution::GKYXC", + "out": "tensor_layout::convolution::NHWGK", + } + + LAYOUT_3D = { + "in": "tensor_layout::convolution::NDHWGC", + "wei": "tensor_layout::convolution::GKZYXC", + "out": "tensor_layout::convolution::NDHWGK", + } + + @classmethod + def get_layouts(cls, ndim: int) -> dict: + if ndim == 1: + return cls.LAYOUT_1D + elif ndim == 2: + return cls.LAYOUT_2D + else: + return cls.LAYOUT_3D + + +# ============================================================================ +# CK Tile Grouped Conv Kernel Generator +# ============================================================================ + + +class CKTileGroupedConvKernelGenerator: + """Generates CK Tile grouped convolution kernel instance code""" + + def __init__( + self, + datatype: str, + variant: GroupedConvVariant = GroupedConvVariant.FORWARD, + ): + self.datatype = datatype + self.variant = variant + self.tm = GroupedConvTypeMappings() + + def generate(self, config: GroupedConvKernelConfig) -> str: + """Generate complete CK Tile grouped convolution kernel""" + kernel_name = config.name(self.datatype) + return f"""{self._header(kernel_name, config)} +{self._config_struct(config, kernel_name)} +{self._kernel_instance(config, kernel_name)} +""" + + def _header(self, kernel_name: str, config: GroupedConvKernelConfig) -> str: + """Generate header includes based on variant""" + if self.variant == GroupedConvVariant.BACKWARD_DATA: + kernel_header = "grouped_convolution_backward_data_kernel.hpp" + elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT: + kernel_header = "grouped_convolution_backward_weight_kernel.hpp" + else: + kernel_header = "grouped_convolution_forward_kernel.hpp" + + elementwise_include = "" + if config.trait.two_stage: + elementwise_include = '\n#include "ck_tile/ops/elementwise.hpp"' + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated CK Tile Grouped Convolution kernel: {kernel_name} +// Variant: {self.variant.value} +#pragma once + +#include +#include +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/grouped_convolution/kernel/{kernel_header}" +#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp"{elementwise_include} + +using namespace ck_tile; +""" + + def _config_struct(self, config: GroupedConvKernelConfig, kernel_name: str) -> str: + """Generate config struct""" + t = config.tile + tr = config.trait + layouts = self.tm.get_layouts(config.ndim_spatial) + + return f""" +// Kernel configuration +struct {kernel_name}_Config {{ + // Data types + using InDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using WeiDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + using AccDataType = float; + using OutDataType = {self.tm.DTYPE_TO_CK[self.datatype]}; + + // Layouts + using InLayout = {layouts["in"]}; + using WeiLayout = {layouts["wei"]}; + using OutLayout = {layouts["out"]}; + + // Tile shape + static constexpr index_t M_Tile = {t.tile_m}; + static constexpr index_t N_Tile = {t.tile_n}; + static constexpr index_t K_Tile = {t.tile_k}; + + static constexpr index_t M_Warp = {t.warp_m}; + static constexpr index_t N_Warp = {t.warp_n}; + static constexpr index_t K_Warp = {t.warp_k}; + + static constexpr index_t M_Warp_Tile = {t.warp_tile_m}; + static constexpr index_t N_Warp_Tile = {t.warp_tile_n}; + static constexpr index_t K_Warp_Tile = {t.warp_tile_k}; + + // Vector sizes + static constexpr index_t VectorSizeA = {config.vector_size_a}; + static constexpr index_t VectorSizeB = {config.vector_size_b}; + static constexpr index_t VectorSizeC = {config.vector_size_c}; + + // Padding + static constexpr bool kPadM = {str(tr.pad_m).lower()}; + static constexpr bool kPadN = {str(tr.pad_n).lower()}; + static constexpr bool kPadK = {str(tr.pad_k).lower()}; + + // Pipeline & Epilogue + static constexpr auto Pipeline = {self.tm.PIPELINE_TO_CK[tr.pipeline]}; + static constexpr auto Scheduler = {self.tm.SCHEDULER_TO_CK[tr.scheduler]}; + static constexpr bool DoubleSmemBuffer = {str(tr.double_smem_buffer).lower()}; + static constexpr bool UseCShuffleEpilogue = {str(tr.epilogue == "cshuffle").lower()}; + + // Other params + static constexpr int kBlockPerCu = {config.block_per_cu}; + static constexpr index_t NumWaveGroups = {config.num_wave_groups}; + static constexpr index_t NumGroupsToMerge = {tr.num_groups_to_merge}; + static constexpr bool EnableSplitImage = {str(tr.split_image).lower()}; + static constexpr bool ExplicitGemm = {str(tr.explicit_gemm).lower()}; + static constexpr index_t NDimSpatial = {config.ndim_spatial}; + + // Target architecture + static constexpr const char* TargetArch = "{config.arch}"; +}}; +""" + + def _kernel_instance( + self, config: GroupedConvKernelConfig, kernel_name: str + ) -> str: + """Generate kernel instantiation code with launch function""" + tr = config.trait + + if self.variant == GroupedConvVariant.BACKWARD_WEIGHT and tr.two_stage: + return self._kernel_instance_two_stage(config, kernel_name) + + # Variant-specific configuration + if self.variant == GroupedConvVariant.BACKWARD_DATA: + host_args_type = "GroupedConvBwdDataHostArgs" + kernel_type = "GroupedConvolutionBackwardDataKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsBwdData" + layout_suffix = "BwdData" + # For bwd_data: A=dOutput, B=Weight, C=dInput + a_dtype = "OutDataType" + b_dtype = "WeiDataType" + c_dtype = "InDataType" + gemm_k_calc = "args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()" + direction_prefix = "BWD_DATA" + launcher_alias = "SelectedConvBwdDataLauncher" + elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT: + host_args_type = "GroupedConvBwdWeightHostArgs" + kernel_type = "GroupedConvolutionBackwardWeightKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsBwdWeight" + layout_suffix = "BwdWeight" + # For bwd_weight: A=dOutput, B=Input, C=dWeight (per CK Tile invoker) + a_dtype = "OutDataType" + b_dtype = "InDataType" + c_dtype = "WeiDataType" + gemm_k_calc = "args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end()" + direction_prefix = "BWD_WEIGHT" + launcher_alias = "SelectedConvBwdWeightLauncher" + else: # Forward + host_args_type = "GroupedConvFwdHostArgs<>" + kernel_type = "GroupedConvolutionForwardKernel" + gemm_traits = "GroupedConvImplicitGemmTraitsFwd" + layout_suffix = "Fwd" + a_dtype = "InDataType" + b_dtype = "WeiDataType" + c_dtype = "OutDataType" + gemm_k_calc = "args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()" + direction_prefix = "FWD" + launcher_alias = "SelectedConvKernelLauncher" + + # Create valid C++ namespace name + ns_name = "ns_" + kernel_name.replace("-", "_") + + return f""" +// Unique namespace for this kernel to avoid conflicts when including multiple kernels +namespace {ns_name} {{ + +// Bring Config into namespace +using Config = {kernel_name}_Config; + +// Kernel name for identification +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}"; + +// Selected kernel alias +using SelectedConv{direction_prefix.title()}Kernel = Config; + +// ============================================================================= +// Kernel Launch Implementation ({self.variant.value}) +// ============================================================================= + +struct {kernel_name}_Launcher {{ + using KernelConfig = Config; // Use the Config alias from namespace + using InDataType = typename Config::InDataType; + using WeiDataType = typename Config::WeiDataType; + using OutDataType = typename Config::OutDataType; + using AccDataType = typename Config::AccDataType; + using InLayout = typename Config::InLayout; + using WeiLayout = typename Config::WeiLayout; + using OutLayout = typename Config::OutLayout; + + static constexpr index_t NDimSpatial = Config::NDimSpatial; + + // Implicit GEMM shape + using GemmShape = TileGemmShape< + sequence, + sequence, + sequence>; + + // Convolution traits + static constexpr auto ConvSpec = ConvolutionSpecialization::Default; + using GroupedConvTraitsType = GroupedConvTraits< + NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout, + Config::VectorSizeA, Config::VectorSizeB, Config::VectorSizeC, + Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>; + + // Tile partitioner + using TilePartitioner = GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + // Universal traits - layout suffix changes per variant + using GemmUniversalTraits = TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + Config::DoubleSmemBuffer, + typename GroupedConvTraitsType::AsLayout{layout_suffix}, + typename GroupedConvTraitsType::BsLayout{layout_suffix}, + typename GroupedConvTraitsType::CLayout{layout_suffix}, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + Config::NumWaveGroups>; + + // Pipeline problem - data types change per variant + using GemmPipelineProblem = GemmPipelineProblem< + {a_dtype}, {b_dtype}, AccDataType, GemmShape, + typename GroupedConvTraitsType::template {gemm_traits}, + element_wise::PassThrough, element_wise::PassThrough, {c_dtype}, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + // Base pipeline for tail handling + using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}; + + static float launch(const {host_args_type}& args, const stream_config& s) {{ + const index_t gemm_k = {gemm_k_calc}, 1, std::multiplies()); + + const index_t k_grain = args.k_batch * Config::K_Tile; + const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = Config::Scheduler; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + {a_dtype}, {b_dtype}, AccDataType, GemmShape, GemmUniversalTraits, + scheduler, + element_wise::PassThrough, element_wise::PassThrough, {c_dtype}, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")}; + + using ConvEpilogue = CShuffleEpilogue, AccDataType, {c_dtype}, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile, + Config::N_Warp_Tile, Config::K_Warp_Tile, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + Config::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + Config::VectorSizeC, false, 1, Config::DoubleSmemBuffer>>; + + using Kernel = {kernel_type}< + GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>; + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{ + auto kargs = Kernel::MakeKernelArgs(args); + + if (!Kernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for grouped conv kernel"); + }} + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + ave_time = launch_kernel(s, make_kernel( + Kernel{{}}, grids, blocks, 0, kargs)); + + return ave_time; + }}; + + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + return ave_time; + }} +}}; + +// Launcher alias for tile_engine compatibility +using {launcher_alias} = {kernel_name}_Launcher; + +}} // namespace {ns_name} + +// Export specific launcher to global namespace +using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher; + +// When used with -include compiler flag, export aliases to global namespace +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {launcher_alias} = {ns_name}::{launcher_alias}; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME; +#endif +""" + + # Pipelines that accept GroupedConvUniversalPipelineAgBgCrPolicy + # as a second template parameter for conv-specific LDS layout. + # (from conv_configs.hpp PipelineTypeTraits -- basic_v1/mem/compv3) + # CompV4/V5/V6/comp_async/basic_async_v1 use their own default policies. + _CONV_POLICY_PIPELINES = {"basic_v1", "mem", "compv3"} + + def _get_pipeline(self, pipeline: str) -> str: + """Get pipeline class name.""" + pipelines = { + "basic_v1": "GemmPipelineAGmemBGmemCRegV1", + "mem": "GemmPipelineAgBgCrMem", + "compv3": "GemmPipelineAgBgCrCompV3", + "compv4": "GemmPipelineAgBgCrCompV4", + "compv5": "GemmPipelineAgBgCrCompV5", + "compv6": "GemmPipelineAgBgCrCompV6", + "comp_async": "GemmPipelineAgBgCrCompAsync", + "basic_async_v1": "GemmPipelineAGmemBGmemCRegAsyncV1", + } + return pipelines.get(pipeline, "GemmPipelineAgBgCrCompV3") + + def _get_pipeline_template_args(self, pipeline: str, problem_type: str) -> str: + """Get full template argument list for pipeline instantiation. + + For basic_v1/mem/compv3, passes GroupedConvUniversalPipelineAgBgCrPolicy + as a second template argument for conv-specific LDS banking. + """ + base = self._get_pipeline(pipeline) + if pipeline in self._CONV_POLICY_PIPELINES: + return f"{base}<{problem_type}, GroupedConvUniversalPipelineAgBgCrPolicy>" + return f"{base}<{problem_type}>" + + def _get_base_pipeline(self, pipeline: str) -> str: + """Get base pipeline class name (used for tail handling only). + + Note: basic_async_v1 inherits from BaseGemmPipelineAGmemBGmemCRegV1 + (there is no separate BaseGemmPipelineAGmemBGmemCRegAsyncV1). + """ + pipelines = { + "basic_v1": "BaseGemmPipelineAGmemBGmemCRegV1", + "mem": "BaseGemmPipelineAgBgCrMem", + "compv3": "BaseGemmPipelineAgBgCrCompV3", + "compv4": "BaseGemmPipelineAgBgCrCompV4", + "compv5": "BaseGemmPipelineAgBgCrCompV5", + "compv6": "BaseGemmPipelineAgBgCrCompV6", + "comp_async": "BaseGemmPipelineAgBgCrCompAsync", + "basic_async_v1": "BaseGemmPipelineAGmemBGmemCRegV1", + } + return pipelines.get(pipeline, "BaseGemmPipelineAgBgCrCompV3") + + def _kernel_instance_two_stage( + self, config: GroupedConvKernelConfig, kernel_name: str + ) -> str: + """Generate two-stage bwd_weight kernel: GEMM into fp32 workspace + ElementWise convert. + + Mirrors grouped_convolution_backward_weight_two_stage_invoker.hpp from + example/ck_tile/20_grouped_convolution/. + """ + tr = config.trait + ns_name = "ns_" + kernel_name.replace("-", "_") + direction_prefix = "BWD_WEIGHT" + launcher_alias = "SelectedConvBwdWeightLauncher" + + return f""" +namespace {ns_name} {{ + +using Config = {kernel_name}_Config; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}"; +using SelectedConv{direction_prefix.title()}Kernel = Config; + +struct {kernel_name}_Launcher {{ + using KernelConfig = Config; + using InDataType = typename Config::InDataType; + using WeiDataType = typename Config::WeiDataType; + using OutDataType = typename Config::OutDataType; + using AccDataType = typename Config::AccDataType; + using InLayout = typename Config::InLayout; + using WeiLayout = typename Config::WeiLayout; + using OutLayout = typename Config::OutLayout; + using WorkspaceDataType = float; + + static constexpr index_t NDimSpatial = Config::NDimSpatial; + // Two-stage forces VectorSizeC = 1 for workspace writes + static constexpr index_t VectorSizeC_TwoStage = 1; + + using GemmShape = TileGemmShape< + sequence, + sequence, + sequence>; + + static constexpr auto ConvSpec = ConvolutionSpecialization::Default; + using GroupedConvTraitsType = GroupedConvTraits< + NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout, + Config::VectorSizeA, Config::VectorSizeB, VectorSizeC_TwoStage, + Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>; + + using TilePartitioner = GemmSpatiallyLocalTilePartitioner< + GemmShape, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum, + GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>; + + using GemmUniversalTraits = TileGemmUniversalTraits< + GroupedConvTraitsType::FixedGemmParams::kPadM, + GroupedConvTraitsType::FixedGemmParams::kPadN, + GroupedConvTraitsType::FixedGemmParams::kPadK, + Config::DoubleSmemBuffer, + typename GroupedConvTraitsType::AsLayoutBwdWeight, + typename GroupedConvTraitsType::BsLayoutBwdWeight, + typename GroupedConvTraitsType::CLayoutBwdWeight, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity, + GroupedConvTraitsType::FixedGemmParams::Persistent, + Config::NumWaveGroups>; + + using GemmPipelineProblem = GemmPipelineProblem< + OutDataType, InDataType, AccDataType, GemmShape, + typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight, + element_wise::PassThrough, element_wise::PassThrough, WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}; + + static float launch(const GroupedConvBwdWeightHostArgs& args, const stream_config& s) {{ + const index_t gemm_k = args.N_ * std::accumulate( + args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end(), + 1, std::multiplies()); + + const index_t k_grain = args.k_batch * Config::K_Tile; + const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile; + const index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{{0}}; + + constexpr auto scheduler = Config::Scheduler; + + using UniversalGemmProblem = UniversalGemmPipelineProblem< + OutDataType, InDataType, AccDataType, GemmShape, GemmUniversalTraits, + scheduler, + element_wise::PassThrough, element_wise::PassThrough, WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>; + + using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")}; + + // Epilogue writes to fp32 workspace (not fp16 output) + using ConvEpilogue = CShuffleEpilogue, AccDataType, WorkspaceDataType, + typename GroupedConvTraitsType::ImplicitGemmDsLayout, + typename GroupedConvTraitsType::FixedGemmParams::ELayout, + element_wise::PassThrough, + TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, + Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile, + Config::N_Warp_Tile, Config::K_Warp_Tile, + GroupedConvTraitsType::FixedGemmParams::TransposeC, + Config::NumWaveGroups, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeC>>; + + using Kernel = GroupedConvolutionBackwardWeightKernel< + GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>; + + // ElementWise kernel: fp32 workspace -> fp16/bf16 output + using XElementwiseOp = element_wise::UnaryConvert; + using EwBlockTile = sequence<2048>; + using EwBlockWarps = sequence<8>; + using EwWarpTile = sequence<64>; + using EwShape = ElementWiseShape; + using EwProblem = ElementWisePipelineProblem< + WorkspaceDataType, WorkspaceDataType, WeiDataType, EwShape, XElementwiseOp>; + using EwKernel = ElementWiseKernel; + + // Workspace: G * K * C * product(filter_spatial) elements in fp32 + const index_t spatial_accum = std::accumulate( + args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(), + 1, std::multiplies()); + DeviceMem ws_buf(args.G_ * args.K_ * args.C_ * spatial_accum * sizeof(WorkspaceDataType)); + + GroupedConvBwdWeightHostArgs ws_args(args); + auto* c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_buf.GetDeviceBuffer(); + + auto kargs = Kernel::MakeKernelArgs(ws_args); + + if(!Kernel::IsSupportedArgument(kargs)) {{ + throw std::runtime_error("Arguments not supported for two-stage bwd_weight kernel"); + }} + + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); + + // ElementWise kernel setup + const index_t ew_block_size = EwKernel::BlockSize(); + const index_t total_elems = args.G_ * args.K_ * args.C_ * spatial_accum; + constexpr index_t elems_per_block = EwBlockTile::at(number<0>{{}}); + const index_t ew_grid_size = (total_elems + elems_per_block - 1) / elems_per_block; + + auto ew_shape = make_tuple(args.G_ * args.K_, + args.C_ * spatial_accum); + auto ew_inputs = make_tuple(static_cast(ws_args.wei_ptr)); + + if(!EwKernel::IsSupportedArgument(ew_shape)) {{ + throw std::runtime_error("ElementWise arguments not supported for two-stage convert"); + }} + + auto preprocess = [&]() {{ + if(kargs.k_batch > 1) + hip_check_error(hipMemsetAsync( + ws_args.wei_ptr, 0, + total_elems * sizeof(WorkspaceDataType), + s.stream_id_)); + }}; + + ave_time = launch_kernel_time_mask( + s, preprocess, + make_kernel(Kernel{{}}, grids, blocks, 0, kargs), + make_kernel( + EwKernel{{}}, ew_grid_size, ew_block_size, 0, + ew_shape, + make_tuple(args.C_ * spatial_accum, 1), + make_tuple(args.C_ * spatial_accum, 1), + ew_inputs, + static_cast(c_ptr))); + + return ave_time; + }} +}}; + +using {launcher_alias} = {kernel_name}_Launcher; + +}} // namespace {ns_name} + +using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher; + +#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE +using {launcher_alias} = {ns_name}::{launcher_alias}; +constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME; +#endif +""" + + +# ============================================================================ +# Dispatcher Wrapper Generator +# ============================================================================ + + +class GroupedConvDispatcherWrapperGenerator: + """Generates dispatcher integration wrapper following GEMM pattern""" + + # Static mappings for pipeline and scheduler enum names (matches kernel_key.hpp) + PIPELINE_TO_DISPATCHER = { + "mem": "Pipeline::Mem", + "compv3": "Pipeline::CompV3", + "compv4": "Pipeline::CompV4", + "compv5": "Pipeline::CompV5", + "preshufflev1": "Pipeline::PreShuffleV1", + "preshufflev2": "Pipeline::PreShuffleV2", + } + + SCHEDULER_TO_DISPATCHER = { + "default": "Scheduler::Default", + "intrawave": "Scheduler::Intrawave", + "interwave": "Scheduler::Interwave", + } + + def __init__( + self, + datatype: str, + variant: GroupedConvVariant = GroupedConvVariant.FORWARD, + ): + self.datatype = datatype + self.variant = variant + + def _pipeline_to_dispatcher(self, pipeline: str) -> str: + """Convert pipeline string to dispatcher enum value""" + return self.PIPELINE_TO_DISPATCHER.get( + pipeline.lower(), f"Pipeline::{pipeline.capitalize()}" + ) + + def _scheduler_to_dispatcher(self, scheduler: str) -> str: + """Convert scheduler string to dispatcher enum value""" + return self.SCHEDULER_TO_DISPATCHER.get( + scheduler.lower(), f"Scheduler::{scheduler.capitalize()}" + ) + + def generate( + self, + config: GroupedConvKernelConfig, + kernel_path: Path, + output_dir: Path, + ) -> str: + """Generate dispatcher wrapper with factory function for registry""" + kernel_name = config.name(self.datatype) + rel_path = kernel_path.relative_to(output_dir) + + # Determine launcher type based on variant + if self.variant == GroupedConvVariant.FORWARD: + launcher_alias = "SelectedConvKernelLauncher" + host_args_type = "GroupedConvFwdHostArgs<>" + conv_type_str = "forward" + elif self.variant == GroupedConvVariant.BACKWARD_DATA: + launcher_alias = "SelectedConvBwdDataLauncher" + host_args_type = "GroupedConvBwdDataHostArgs" + conv_type_str = "bwd_data" + else: # BACKWARD_WEIGHT + launcher_alias = "SelectedConvBwdWeightLauncher" + host_args_type = "GroupedConvBwdWeightHostArgs" + conv_type_str = "bwd_weight" + + return f"""// SPDX-License-Identifier: MIT +// Auto-generated dispatcher wrapper for: {kernel_name} +#pragma once + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "../{rel_path}" + +namespace ck_tile {{ +namespace dispatcher {{ +namespace generated {{ + +using ::ck_tile::dispatcher::GroupedConvKernelInstancePtr; +using ::ck_tile::dispatcher::GroupedConvKernelKey; +using ::ck_tile::dispatcher::DataType; +using ::ck_tile::dispatcher::LayoutTag; +using ::ck_tile::dispatcher::Pipeline; +using ::ck_tile::dispatcher::Scheduler; +using ::ck_tile::dispatcher::Epilogue; +using Priority = ::ck_tile::dispatcher::GroupedConvRegistry::Priority; + +// Factory function to create kernel instance for registry +inline GroupedConvKernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{ + GroupedConvKernelKey key; + key.signature.dtype_in = DataType::FP16; + key.signature.dtype_wei = DataType::FP16; + key.signature.dtype_out = DataType::FP16; + key.signature.dtype_acc = DataType::FP32; + key.signature.layout = "nhwgc"; + key.signature.conv_type = "{conv_type_str}"; + key.signature.num_dims = {config.ndim_spatial}; + key.signature.groups = 1; + + key.algorithm.tile_shape = {{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}}; + key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, 1}}; + key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}}; + key.algorithm.pipeline = {self._pipeline_to_dispatcher(config.trait.pipeline)}; + key.algorithm.scheduler = {self._scheduler_to_dispatcher(config.trait.scheduler)}; + key.algorithm.epilogue = Epilogue::CShuffle; + key.gfx_arch = gfx_arch; + + // Create kernel instance that wraps the launcher + return std::make_shared( + key, + "{kernel_name}", + []({host_args_type}& args, const stream_config& cfg) -> float {{ + return {kernel_name}_Launcher::launch(args, cfg); + }} + ); +}} + +}} // namespace generated +}} // namespace dispatcher +}} // namespace ck_tile + +// Export launcher alias to global namespace for direct use +using {launcher_alias} = {kernel_name}_Launcher; +""" + + +# ============================================================================ +# Configuration Parser +# ============================================================================ + + +def get_default_configs( + arch: str = "gfx942", + variants: Optional[List[GroupedConvVariant]] = None, + ndims: Optional[List[int]] = None, +) -> List[GroupedConvKernelConfig]: + """Get default grouped convolution configurations for target architecture""" + configs = [] + + if variants is None: + variants = [GroupedConvVariant.FORWARD] + if ndims is None: + ndims = [2] + + # Valid configurations per variant (based on CK Tile example configs) + # Forward and Backward Data: standard GEMM-like tiles + fwd_bwd_data_tiles = [ + # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k) + (128, 128, 32, 2, 2, 32, 32, 16), # Standard 128x128 + (256, 256, 32, 2, 2, 32, 32, 16), # Large 256x256 + (64, 64, 32, 1, 4, 16, 16, 16), # Small 64x64 + (128, 64, 32, 2, 2, 32, 32, 16), # Rectangular + (16, 64, 64, 1, 4, 16, 16, 32), # Tall and narrow + ] + + # Backward Weight: VERY specific tile configs that work with CK Tile's bwd_weight kernel + # Based on ConvConfigComputeV3 from CK Tile examples (example/ck_tile/20_grouped_convolution/) + # Note: Backward weight has strict constraints on warp configurations due to transpose_tile2d + # Only specific warp configs work: (1, 4, 1) and (4, 1, 1) are known to work + bwd_weight_tiles = [ + # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k) + # ConvConfigComputeV3: The primary working config for backward weight + (16, 64, 64, 1, 4, 16, 16, 32), + ] + + for variant in variants: + # Select tile configs based on variant + if variant == GroupedConvVariant.BACKWARD_WEIGHT: + tile_configs = bwd_weight_tiles + # Backward weight ONLY supports compv3 (compv4/compv5 have transpose_tile2d issues) + pipelines = [("compv3", "cshuffle")] + # Also generate two-stage variants (fp32 workspace + elementwise convert) + two_stage_flags = [False, True] + elif variant == GroupedConvVariant.BACKWARD_DATA: + tile_configs = fwd_bwd_data_tiles + # Backward data ONLY supports compv3 (compv4 has get_length issues in bwd_data kernel) + pipelines = [("compv3", "cshuffle")] + two_stage_flags = [False] + else: + tile_configs = fwd_bwd_data_tiles + # Only forward grouped convolution supports both compv3 and compv4 + pipelines = [("compv3", "cshuffle"), ("compv4", "cshuffle")] + two_stage_flags = [False] + for ndim in ndims: + for pipeline, epilogue in pipelines: + for ( + tile_m, + tile_n, + tile_k, + warp_m, + warp_n, + warp_tile_m, + warp_tile_n, + warp_tile_k, + ) in tile_configs: + for two_stage in two_stage_flags: + adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k + + trait = GroupedConvTraitConfig( + pipeline=pipeline, + scheduler="intrawave", + epilogue=epilogue, + double_smem_buffer=(pipeline == "compv4"), + pad_m=True, + pad_n=True, + pad_k=True, + two_stage=two_stage, + ) + + if not trait.is_valid(): + continue + + config = GroupedConvKernelConfig( + tile=TileConfig( + tile_m=tile_m, + tile_n=tile_n, + tile_k=adj_tile_k, + warp_m=warp_m, + warp_n=warp_n, + warp_k=1, + warp_tile_m=warp_tile_m, + warp_tile_n=warp_tile_n, + warp_tile_k=warp_tile_k, + ), + trait=trait, + variant=variant, + ndim_spatial=ndim, + arch=arch, + ) + + if config.is_valid_for_arch(): + configs.append(config) + + return configs + + +def get_arch_filter(): + """Get arch filter if available""" + try: + from arch_filter import ArchFilter + + return ArchFilter + except ImportError: + return None + + +# ============================================================================ +# Main Generator +# ============================================================================ + + +class _GenItem: + """Item for parallel generation with progress logging.""" + + def __init__( + self, + idx: int, + total: int, + config: GroupedConvKernelConfig, + datatype: str, + variant: GroupedConvVariant, + ): + self.idx = idx + self.total = total + self.config = config + self.datatype = datatype + self.variant = variant + + def __str__(self) -> str: + return f"kernel {self.idx}/{self.total}: {self.config.name(self.datatype)}" + + +class UnifiedGroupedConvCodegen: + """Main grouped convolution code generator""" + + def __init__( + self, + output_dir: Path, + gpu_target: str = "gfx942", + datatype: str = "fp16", + ndim_spatial: int = 2, + enable_arch_filter: bool = True, + ): + self.output_dir = output_dir + self.output_dir.mkdir(parents=True, exist_ok=True) + + # Create wrapper directory for dispatcher integration + self.wrapper_dir = self.output_dir / "dispatcher_wrappers" + self.wrapper_dir.mkdir(parents=True, exist_ok=True) + + self.generated_files: List[Path] = [] + self.generated_wrappers: List[Path] = [] + self.gpu_target = gpu_target + self.datatype = datatype + self.ndim_spatial = ndim_spatial + + # Initialize architecture filter for GPU-specific validation + self.arch_filter = None + if enable_arch_filter and HAS_ARCH_FILTER: + try: + self.arch_filter = ArchFilter(gpu_target, strict_mode=False) + log.info(f"Architecture filter enabled for {gpu_target}") + except ValueError as e: + log.warning(f"Could not create arch filter: {e}") + + def _get_configs(self) -> List[GroupedConvKernelConfig]: + """Get configurations for this codegen's datatype and ndim_spatial.""" + return get_default_configs( + arch=self.gpu_target, + variants=[ + GroupedConvVariant.FORWARD, + GroupedConvVariant.BACKWARD_DATA, + GroupedConvVariant.BACKWARD_WEIGHT, + ], + ndims=[self.ndim_spatial], + ) + + def _get_operator_type( + self, variant: GroupedConvVariant + ) -> Optional["OperatorType"]: + """Map GroupedConvVariant to OperatorType for arch validation""" + if OperatorType is None: + return None + + variant_to_operator = { + GroupedConvVariant.FORWARD: OperatorType.CONV_FWD, + GroupedConvVariant.BACKWARD_DATA: OperatorType.CONV_BWD_DATA, + GroupedConvVariant.BACKWARD_WEIGHT: OperatorType.CONV_BWD_WEIGHT, + } + return variant_to_operator.get(variant, OperatorType.CONV_FWD) + + def is_config_valid( + self, config: GroupedConvKernelConfig, datatype: str = "fp16" + ) -> bool: + """Validate configuration against architecture constraints""" + if not self.arch_filter or not HAS_ARCH_FILTER: + return True + + operator = self._get_operator_type(config.variant) + + return self.arch_filter.is_kernel_valid( + datatype_a=datatype, + datatype_b=datatype, + datatype_c=datatype, + tile_m=config.tile.tile_m, + tile_n=config.tile.tile_n, + tile_k=config.tile.tile_k, + warp_m=config.tile.warp_m, + warp_n=config.tile.warp_n, + warp_k=1, # Grouped conv typically uses warp_k=1 + warp_tile_m=config.tile.warp_tile_m, + warp_tile_n=config.tile.warp_tile_n, + warp_tile_k=config.tile.warp_tile_k, + pipeline=config.trait.pipeline, + epilogue=config.trait.epilogue, + scheduler=config.trait.scheduler, + operator=operator, + ) + + def generate_kernel( + self, + config: GroupedConvKernelConfig, + datatype: str, + variant: GroupedConvVariant = GroupedConvVariant.FORWARD, + ) -> Tuple[Path, Path]: + """Generate a single kernel file and dispatcher wrapper. Returns (kernel_path, wrapper_path).""" + kernel_gen = CKTileGroupedConvKernelGenerator(datatype, variant) + wrapper_gen = GroupedConvDispatcherWrapperGenerator(datatype, variant) + + kernel_name = config.name(datatype) + filename = f"{kernel_name}.hpp" + filepath = self.output_dir / filename + + # Generate kernel header + content = kernel_gen.generate(config) + filepath.write_text(content) + self.generated_files.append(filepath) + + # Generate dispatcher wrapper + wrapper_content = wrapper_gen.generate(config, filepath, self.output_dir) + wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp" + wrapper_path.write_text(wrapper_content) + self.generated_wrappers.append(wrapper_path) + + # Generate .cpp compilation unit for per-kernel parallel builds + cpp_filename = f"{kernel_name}.cpp" + cpp_filepath = self.output_dir / cpp_filename + cpp_content = f"""// SPDX-License-Identifier: MIT +// Auto-generated compilation unit for: {kernel_name} +// Enables per-kernel parallel compilation with make -j + +#include "{filename}" + +namespace ck_tile {{ namespace generated {{ + volatile bool _{kernel_name.replace("-", "_")}_loaded = true; +}} }} +""" + cpp_filepath.write_text(cpp_content) + + return filepath, wrapper_path + + def _generate_single_kernel(self, item: _GenItem): + """Generate one kernel (used by parallel_generate). Returns (kernel_path, wrapper_path) or raises.""" + kernel_path, wrapper_path = self.generate_kernel( + item.config, item.datatype, item.variant + ) + log.info( + "Generated kernel %d/%d: %s", + item.idx, + item.total, + item.config.name(item.datatype), + ) + return (kernel_path, wrapper_path) + + def generate_all( + self, + configs: Optional[List[GroupedConvKernelConfig]] = None, + datatypes: Optional[List[str]] = None, + parallel: bool = True, + ) -> dict: + """Generate all kernel files (optionally in parallel). + + Configs are filtered using architecture validation before generation. + Returns dict with keys: kernels, wrappers, failed. + """ + if configs is None: + configs = self._get_configs() + if datatypes is None: + datatypes = [self.datatype] + + results = {"kernels": [], "wrappers": [], "failed": []} + + # Filter configs using arch validation + valid_tasks = [] + rejected_count = 0 + + for datatype in datatypes: + for config in configs: + if self.is_config_valid(config, datatype): + valid_tasks.append((config, datatype, config.variant)) + else: + rejected_count += 1 + log.debug( + f"Rejected config for {self.gpu_target}: " + f"{config.tile.tile_m}x{config.tile.tile_n}x{config.tile.tile_k} " + f"variant={config.variant.value}" + ) + + if rejected_count > 0: + log.info( + f"Filtered {rejected_count} configs for {self.gpu_target}, " + f"{len(valid_tasks)} remaining" + ) + + total = len(valid_tasks) + items = [ + _GenItem(i, total, config, datatype, variant) + for i, (config, datatype, variant) in enumerate(valid_tasks) + ] + + def _safe_generate(item: _GenItem): + """Wrapper that catches exceptions for failure tracking.""" + try: + k, w = self._generate_single_kernel(item) + return ("ok", k, w, None) + except Exception as e: + return ("fail", None, None, str(e)) + + raw = parallel_generate( + _safe_generate, items, parallel=parallel and len(items) > 1 + ) + for r in raw: + if r[0] == "ok": + results["kernels"].append(r[1]) + results["wrappers"].append(r[2]) + else: + results["failed"].append(r[3]) + log.error("Failed: %s", r[3]) + + # Generate include_all_*.hpp headers for Python ctypes libraries + if results["wrappers"]: + self._generate_include_all_headers() + + return results + + def _generate_include_all_headers(self): + """Generate include_all_grouped_conv_*.hpp headers and registration header""" + # Scan output directory for ALL kernel files (not just this run's generated_files) + # This handles the case where fwd and bwd kernels are generated in separate make targets + fwd_headers = [] + bwd_data_headers = [] + bwd_weight_headers = [] + fwd_kernels = [] + bwd_data_kernels = [] + bwd_weight_kernels = [] + + for filepath in self.output_dir.glob("grouped_conv_*.hpp"): + name = filepath.name + kernel_name = name[:-4] + if name.startswith("grouped_conv_fwd_"): + fwd_headers.append(name) + fwd_kernels.append(kernel_name) + elif name.startswith(("grouped_conv_bwd_data_", "grouped_conv_bwdd_")): + bwd_data_headers.append(name) + bwd_data_kernels.append(kernel_name) + elif name.startswith(("grouped_conv_bwd_weight_", "grouped_conv_bwdw_")): + bwd_weight_headers.append(name) + bwd_weight_kernels.append(kernel_name) + + headers_to_generate = [ + ("include_all_grouped_conv_fwd_kernels.hpp", fwd_headers, "forward"), + ( + "include_all_grouped_conv_bwd_data_kernels.hpp", + bwd_data_headers, + "backward data", + ), + ( + "include_all_grouped_conv_bwd_weight_kernels.hpp", + bwd_weight_headers, + "backward weight", + ), + ] + + for header_name, kernel_headers, variant_desc in headers_to_generate: + header_path = self.output_dir / header_name + includes = "\n".join(f'#include "{h}"' for h in sorted(kernel_headers)) + + # Pick the first kernel as the default Selected*Launcher + if kernel_headers: + first_kernel = sorted(kernel_headers)[0][:-4] # Remove .hpp + if variant_desc == "forward": + launcher_alias = ( + f"using SelectedConvKernelLauncher = {first_kernel}_Launcher;" + ) + elif variant_desc == "backward data": + launcher_alias = ( + f"using SelectedConvBwdDataLauncher = {first_kernel}_Launcher;" + ) + else: # backward weight + launcher_alias = f"using SelectedConvBwdWeightLauncher = {first_kernel}_Launcher;" + else: + launcher_alias = "// No kernels generated for this variant" + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Auto-generated header for grouped conv {variant_desc} kernels +#pragma once + +{includes} + +// Default launcher alias (uses first kernel) +{launcher_alias} +""" + header_path.write_text(content) + if kernel_headers: + log.info(f"Generated: {header_name} ({len(kernel_headers)} kernels)") + + # Generate registration header (following GEMM pattern) + self._generate_registration_header( + fwd_kernels, bwd_data_kernels, bwd_weight_kernels + ) + + def _generate_registration_header( + self, + fwd_kernels: List[str], + bwd_data_kernels: List[str], + bwd_weight_kernels: List[str], + ): + """Generate master registration header for all grouped conv kernels""" + # Scan wrapper directory for ALL wrapper files + all_wrappers = [] + for wrapper_path in self.wrapper_dir.glob( + "dispatcher_wrapper_grouped_conv_*.hpp" + ): + all_wrappers.append(wrapper_path.name) + + wrapper_includes = "\n".join(f'#include "{w}"' for w in sorted(all_wrappers)) + + # Generate registration calls + fwd_registrations = "\n ".join( + f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);" + for k in sorted(fwd_kernels) + ) + bwd_data_registrations = "\n ".join( + f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);" + for k in sorted(bwd_data_kernels) + ) + bwd_weight_registrations = "\n ".join( + f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);" + for k in sorted(bwd_weight_kernels) + ) + + content = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +// Auto-generated master registration header for grouped conv kernels +#pragma once + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" + +{wrapper_includes} + +namespace ck_tile {{ +namespace dispatcher {{ + +using Priority = GroupedConvRegistry::Priority; + +inline void register_all_grouped_conv_fwd_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = GroupedConvRegistry::instance(); + {fwd_registrations if fwd_registrations else "// No forward kernels"} +}} + +inline void register_all_grouped_conv_bwd_data_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = GroupedConvRegistry::instance(); + {bwd_data_registrations if bwd_data_registrations else "// No backward data kernels"} +}} + +inline void register_all_grouped_conv_bwd_weight_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + auto& registry = GroupedConvRegistry::instance(); + {bwd_weight_registrations if bwd_weight_registrations else "// No backward weight kernels"} +}} + +inline void register_all_grouped_conv_kernels( + const std::string& gfx_arch = "gfx942", + Priority priority = Priority::Normal) +{{ + register_all_grouped_conv_fwd_kernels(gfx_arch, priority); + register_all_grouped_conv_bwd_data_kernels(gfx_arch, priority); + register_all_grouped_conv_bwd_weight_kernels(gfx_arch, priority); +}} + +inline std::size_t get_grouped_conv_fwd_kernel_count() {{ return {len(fwd_kernels)}; }} +inline std::size_t get_grouped_conv_bwd_data_kernel_count() {{ return {len(bwd_data_kernels)}; }} +inline std::size_t get_grouped_conv_bwd_weight_kernel_count() {{ return {len(bwd_weight_kernels)}; }} +inline std::size_t get_grouped_conv_kernel_count() {{ return {len(fwd_kernels) + len(bwd_data_kernels) + len(bwd_weight_kernels)}; }} + +}} // namespace dispatcher +}} // namespace ck_tile +""" + reg_path = self.wrapper_dir / "register_all_grouped_conv_kernels.hpp" + reg_path.write_text(content) + log.info(f"Generated registration header: {reg_path}") + + +# ============================================================================ +# CLI +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Unified Grouped Convolution Code Generator" + ) + parser.add_argument( + "--output", + "-o", + type=Path, + default=Path("build/generated_kernels"), + help="Output directory", + ) + parser.add_argument( + "--datatype", + "-d", + type=str, + nargs="+", + default=["fp16"], + choices=["fp16", "bf16", "fp32"], + help="Data types to generate", + ) + parser.add_argument( + "--variant", + "-v", + type=str, + nargs="+", + default=["forward"], + choices=["forward", "bwd_data", "bwd_weight"], + help="Grouped convolution variants", + ) + parser.add_argument( + "--ndim", + "-n", + type=int, + nargs="+", + default=[2], + choices=[1, 2, 3], + help="Spatial dimensions", + ) + parser.add_argument( + "--arch", + "-a", + type=str, + default="gfx942", + choices=["gfx90a", "gfx942", "gfx950", "gfx1201"], + help="Target GPU architecture", + ) + parser.add_argument("--verbose", action="store_true", help="Verbose output") + parser.add_argument( + "--list-configs", + action="store_true", + help="List configurations without generating", + ) + + # Individual kernel configuration (when not using predefined configs) + parser.add_argument("--tile-m", type=int, help="Block tile M dimension") + parser.add_argument("--tile-n", type=int, help="Block tile N dimension") + parser.add_argument("--tile-k", type=int, help="Block tile K dimension") + parser.add_argument("--warp-m", type=int, help="Wave distribution M") + parser.add_argument("--warp-n", type=int, help="Wave distribution N") + parser.add_argument("--warp-k", type=int, default=1, help="Wave distribution K") + parser.add_argument("--warp-tile-m", type=int, help="Warp tile M") + parser.add_argument("--warp-tile-n", type=int, help="Warp tile N") + parser.add_argument("--warp-tile-k", type=int, default=16, help="Warp tile K") + parser.add_argument( + "--pipeline", + type=str, + choices=["mem", "compv3", "compv4", "compv5"], + help="Pipeline type", + ) + parser.add_argument( + "--scheduler", + type=str, + choices=["intrawave", "interwave"], + help="Scheduler type", + ) + parser.add_argument( + "--epilogue", + type=str, + default="cshuffle", + choices=["cshuffle", "default"], + help="Epilogue type", + ) + parser.add_argument("--pad-m", type=bool, default=True, help="Pad M dimension") + parser.add_argument("--pad-n", type=bool, default=True, help="Pad N dimension") + parser.add_argument("--pad-k", type=bool, default=True, help="Pad K dimension") + parser.add_argument("--vector-a", type=int, default=4, help="Vector size A") + parser.add_argument("--vector-b", type=int, default=8, help="Vector size B") + parser.add_argument("--vector-c", type=int, default=8, help="Vector size C") + parser.add_argument("--block-per-cu", type=int, default=1, help="Blocks per CU") + parser.add_argument("--num-wave-groups", type=int, default=1, help="Wave groups") + parser.add_argument( + "--num-groups-to-merge", type=int, default=1, help="Groups to merge" + ) + parser.add_argument( + "--double-smem-buffer", + type=str, + default=None, + help="Double SMEM buffer (true/false)", + ) + + args = parser.parse_args() + + if args.verbose: + logging.getLogger().setLevel(logging.DEBUG) + + # Map variant strings to enums + variant_map = { + "forward": GroupedConvVariant.FORWARD, + "bwd_data": GroupedConvVariant.BACKWARD_DATA, + "bwd_weight": GroupedConvVariant.BACKWARD_WEIGHT, + } + requested_variants = [variant_map[v] for v in args.variant] + + # Check if user specified custom configuration + custom_config = ( + args.tile_m is not None or args.tile_n is not None or args.pipeline is not None + ) + + if custom_config: + # Build custom config from CLI arguments + tile = TileConfig( + tile_m=args.tile_m or 128, + tile_n=args.tile_n or 128, + tile_k=args.tile_k or 64, + warp_m=args.warp_m or 2, + warp_n=args.warp_n or 2, + warp_k=args.warp_k or 1, + warp_tile_m=args.warp_tile_m or 32, + warp_tile_n=args.warp_tile_n or 32, + warp_tile_k=args.warp_tile_k or 16, + ) + pipeline = args.pipeline or "compv4" + # Determine double_smem_buffer: use CLI arg if given, else default based on pipeline + if args.double_smem_buffer is not None: + dsb = args.double_smem_buffer.lower() == "true" + else: + dsb = pipeline == "compv4" # compv4 requires double buffer + + trait = GroupedConvTraitConfig( + pipeline=pipeline, + scheduler=args.scheduler or "intrawave", + epilogue=args.epilogue or "cshuffle", + pad_m=args.pad_m, + pad_n=args.pad_n, + pad_k=args.pad_k, + double_smem_buffer=dsb, + num_groups_to_merge=args.num_groups_to_merge, + ) + config = GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=requested_variants[0] + if requested_variants + else GroupedConvVariant.FORWARD, + ndim_spatial=args.ndim[0] if args.ndim else 2, + arch=args.arch, + vector_size_a=args.vector_a, + vector_size_b=args.vector_b, + vector_size_c=args.vector_c, + block_per_cu=args.block_per_cu, + num_wave_groups=args.num_wave_groups, + ) + filtered_configs = [config] + else: + # Get predefined configurations for target arch with requested variants and ndims + filtered_configs = get_default_configs( + arch=args.arch, variants=requested_variants, ndims=args.ndim + ) + + if args.list_configs: + print(f"Grouped convolution configurations for {args.arch}:") + print(f" Datatypes: {args.datatype}") + print(f" Variants: {args.variant}") + print(f" Spatial dims: {args.ndim}") + print(f"\nConfigurations ({len(filtered_configs)}):") + for cfg in filtered_configs: + print(f" - {cfg.name('fp16')}") + print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}") + print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}") + print( + f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}" + ) + print( + f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}" + ) + print( + f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}" + ) + return + + # Generate + codegen = UnifiedGroupedConvCodegen( + output_dir=args.output, + gpu_target=args.arch, + enable_arch_filter=True, + ) + results = codegen.generate_all( + configs=filtered_configs, datatypes=args.datatype, parallel=True + ) + + print( + f"\nGenerated {len(results['kernels'])} grouped convolution kernel files " + f"for {args.arch} in {args.output}" + ) + if results["failed"]: + print(f" Failed: {len(results['failed'])}") + for err in results["failed"][:5]: + print(f" - {err}") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/examples/CMakeLists.txt b/dispatcher/examples/CMakeLists.txt index bda8eb0372..ab094e90cf 100644 --- a/dispatcher/examples/CMakeLists.txt +++ b/dispatcher/examples/CMakeLists.txt @@ -187,7 +187,6 @@ function(add_gpu_example NAME SOURCE KERNEL_HEADER) if(HEADER_NAME STREQUAL "register_all_kernels.hpp") # Registration header - examples include it directly target_compile_options(${NAME} PRIVATE - -DGEMM_KERNEL_AVAILABLE=1 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal @@ -315,6 +314,7 @@ function(add_declarative_gpu_example NAME SOURCE) target_include_directories(${NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/../../include ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CMAKE_CURRENT_SOURCE_DIR}/../.. ${EXAMPLE_KERNEL_DIR} ${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers ) @@ -322,7 +322,6 @@ function(add_declarative_gpu_example NAME SOURCE) # Force-include the generated registration header target_compile_options(${NAME} PRIVATE -include ${EXAMPLE_HEADER} - -DGEMM_KERNEL_AVAILABLE=1 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal @@ -345,6 +344,7 @@ add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_v add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp) add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp) add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp) +add_declarative_gpu_example(gemm_07_gfx950_minimal gemm/cpp/07_gfx950_minimal.cpp) # ML Heuristic example -- requires LightGBM shared library # Derive site-packages from active Python interpreter (respects virtualenvs) @@ -443,19 +443,79 @@ if(hip_FOUND) endif() add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel) +# ============================================================================= +# Grouped Convolution C++ Examples +# ============================================================================= + +add_declarative_gpu_example(grouped_conv_01_basic grouped_conv/cpp/01_basic_grouped_conv.cpp) +add_declarative_gpu_example(grouped_conv_02_all_dirs grouped_conv/cpp/02_all_directions.cpp) +add_declarative_gpu_example(grouped_conv_03_bench_val grouped_conv/cpp/03_benchmark_validation.cpp) +add_declarative_gpu_example(grouped_conv_04_registry_json grouped_conv/cpp/04_registry_json.cpp) +add_declarative_gpu_example(grouped_conv_05_bwd_data grouped_conv/cpp/05_bwd_data.cpp) +add_declarative_gpu_example(grouped_conv_06_bwd_weight grouped_conv/cpp/06_bwd_weight.cpp) +add_declarative_gpu_example(grouped_conv_07_benchmark grouped_conv/cpp/07_multi_tile_benchmark.cpp) + +# ============================================================================= +# Grouped Convolution Python Library - Multi-Kernel (fwd/bwd_data/bwd_weight x 2D/3D) +# ============================================================================= + +# Kernel output directory for the Python conv library +set(CONV_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/conv_python_fallback") +set(CONV_DISPATCH_HEADER "${CONV_FALLBACK_KERNEL_DIR}/conv_python_dispatch.hpp") + +# Generate ALL conv kernels (fwd/bwd_data/bwd_weight x 2D/3D x multiple tile configs) +# then create the dispatch header with 2D/3D aliases +add_custom_command( + OUTPUT ${CONV_DISPATCH_HEADER} + COMMAND ${CMAKE_COMMAND} -E make_directory ${CONV_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_grouped_conv_codegen.py + --variant forward bwd_data bwd_weight --ndim 2 3 + --datatype fp16 --arch ${GPU_TARGET} + --output ${CONV_FALLBACK_KERNEL_DIR} + COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/generate_conv_dispatch_header.py + --kernel-dir ${CONV_FALLBACK_KERNEL_DIR} + --output ${CONV_DISPATCH_HEADER} + COMMENT "Generating conv kernels (fwd/bwd_data/bwd_weight x 2D/3D) for Python library..." + VERBATIM +) + +add_custom_target(generate_conv_fallback_kernels DEPENDS ${CONV_DISPATCH_HEADER}) + +# Conv dynamic library for Python (all 6 kernel variants) +add_library(dispatcher_conv_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/conv_ctypes_lib.cpp) +target_link_libraries(dispatcher_conv_lib PRIVATE ck_tile_dispatcher) +target_include_directories(dispatcher_conv_lib PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../include + ${CMAKE_CURRENT_SOURCE_DIR}/../include + ${CONV_FALLBACK_KERNEL_DIR} +) +target_compile_options(dispatcher_conv_lib PRIVATE + -include ${CONV_DISPATCH_HEADER} + -DGFX_ARCH="${GPU_TARGET}" + -mllvm -enable-noalias-to-md-conversion=0 + -Wno-undefined-func-template + -Wno-float-equal + --offload-compress +) +if(hip_FOUND) + target_link_libraries(dispatcher_conv_lib PRIVATE hip::device hip::host) +endif() +add_dependencies(dispatcher_conv_lib generate_conv_fallback_kernels) + message(STATUS "GEMM examples configured - kernels will be generated during 'make'") +message(STATUS "Grouped Conv examples configured - kernels will be generated during 'make'") # Convenience target to build all Python ctypes libraries add_custom_target(python_libs - DEPENDS dispatcher_gemm_lib - COMMENT "Building Python ctypes libraries (GEMM)" + DEPENDS dispatcher_gemm_lib dispatcher_conv_lib + COMMENT "Building Python ctypes libraries (GEMM + Conv)" ) # ============================================================================= # Per-Architecture Kernel Generation Targets # ============================================================================= -set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030) +set(SUPPORTED_GPU_ARCHS gfx942 gfx950 gfx90a gfx1100 gfx1030) foreach(ARCH ${SUPPORTED_GPU_ARCHS}) # GEMM kernels for this arch diff --git a/dispatcher/examples/README.md b/dispatcher/examples/README.md index fdee9c3583..24bea821ba 100644 --- a/dispatcher/examples/README.md +++ b/dispatcher/examples/README.md @@ -1,8 +1,6 @@ # CK Tile Dispatcher Examples -Comprehensive examples for GEMM operations with GPU execution. - -> **Note**: Convolution examples have been moved to `ck-2/conv_archive/` for reference. +Comprehensive examples for GEMM and Grouped Convolution operations with GPU execution. --- @@ -60,11 +58,11 @@ python3 examples/gemm/python/08_heuristics.py ``` examples/ -├── gemm/ -│ ├── cpp/ # 6 C++ GEMM examples -│ └── python/ # 11 Python GEMM examples -│ -└── README.md +|---- gemm/ +| |---- cpp/ # 6 C++ GEMM examples +| +---- python/ # 11 Python GEMM examples +| ++---- README.md ``` --- @@ -201,10 +199,31 @@ rocminfo | grep "Name:" --- -## Archived Examples +## Grouped Convolution -Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`: -- `examples/conv/cpp/` - 11 C++ convolution examples -- `examples/conv/python/` - 14 Python convolution examples +Grouped convolution support has been re-introduced with a unified infrastructure shared with GEMM. -See the archive for convolution functionality reference. +### Infrastructure + +The grouped convolution code generation, utilities, and build scripts are available: + +| Component | Location | +|-----------|----------| +| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` | +| Python Codegen | `codegen/unified_grouped_conv_codegen.py` | +| Python Utils | `python/grouped_conv_utils.py` | +| Build Script | `scripts/compile_grouped_conv_examples.py` | + +### Building Grouped Conv Kernels + +```bash +# Generate grouped conv kernels +python3 codegen/unified_grouped_conv_codegen.py \ + --output-dir build/generated_kernels \ + --datatype fp16 --variant forward --ndim-spatial 2 + +# Compile a grouped conv example +python3 scripts/compile_grouped_conv_examples.py my_grouped_conv_example.cpp +``` + +See the [main README](../README.md#grouped-convolution-support) for more details. diff --git a/dispatcher/examples/gemm/cpp/02_multi_size.cpp b/dispatcher/examples/gemm/cpp/02_multi_size.cpp index 5e620209f4..ffd2858be4 100644 --- a/dispatcher/examples/gemm/cpp/02_multi_size.cpp +++ b/dispatcher/examples/gemm/cpp/02_multi_size.cpp @@ -21,9 +21,9 @@ * - pipeline: "compv3" -> 1 option (compv4 requires special handling) * - scheduler: "intrawave" -> 1 option * - * Raw expansion: 3 × 2 = 6 configs, but arch filter validates each: - * - tile_m must be divisible by (warp_m × warp_tile_m) - * - tile_n must be divisible by (warp_n × warp_tile_n) + * Raw expansion: 3 x 2 = 6 configs, but arch filter validates each: + * - tile_m must be divisible by (warp_m x warp_tile_m) + * - tile_n must be divisible by (warp_n x warp_tile_n) * - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16) * Result: 4 valid wildcard kernels + 1 explicit = 5 total * @@ -70,13 +70,13 @@ DECL_KERNEL_SET(multi_size_kernels, .add(Signature().dtype("fp16").layout("rcr"), Algorithm() .tile(64, 64, 64) - .wave(ANY_INT, ANY_INT, 1) // ANY_INT → (1,4,1), (2,2,1), (4,1,1) - .warp(-1, -1, -1) // -1 same as ANY_INT → (16,16,32), (32,32,16) - .pipeline("*") // "*" → valid pipelines - .scheduler("*") // "*" → valid schedulers + .wave(ANY_INT, ANY_INT, 1) // ANY_INT -> (1,4,1), (2,2,1), (4,1,1) + .warp(-1, -1, -1) // -1 same as ANY_INT -> (16,16,32), (32,32,16) + .pipeline("*") // "*" -> valid pipelines + .scheduler("*") // "*" -> valid schedulers .epilogue("cshuffle"), "gfx942")); -// Raw: 3×2=6, arch filter removes 2 invalid → 4 valid kernels +// Raw: 3x2=6, arch filter removes 2 invalid -> 4 valid kernels // ============================================================================= // MAIN @@ -116,8 +116,8 @@ int main(int argc, char* argv[]) .pipeline("*") -> expands to valid pipelines = 1 .scheduler("*") -> expands to valid schedulers = 1 - Expanded: 3 × 2 = 6 configs, but arch filter validates each: - - wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64 + Expanded: 3 x 2 = 6 configs, but arch filter validates each: + - wave x warp must divide tile: (4,1,1)x(32,32,16) invalid for 64x64 - Result: 4 valid kernels from wildcard + 1 explicit = 5 total )"; diff --git a/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp b/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp new file mode 100644 index 0000000000..7e62ad2e4f --- /dev/null +++ b/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp @@ -0,0 +1,191 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * Example 07: Minimal gfx950 (CDNA4 / MI350) GEMM + * + * Demonstrates the dispatcher working with gfx950-specific kernels: + * + * - fp16 GEMM with standard tile configs + * - fp8 GEMM with gfx950-extended warp tiles (16x16x128) + * - 160KB LDS: gfx950 doubles the LDS from 64KB to 160KB + * + * Build: cd dispatcher/build && cmake .. -DGPU_TARGETS=gfx950 && make gemm_07_gfx950_minimal + */ + +#include +#include +#include +#include + +#include "ck_tile/dispatcher.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::backends; +using namespace ck_tile::dispatcher::utils; +using Signature = decl::Signature; +using Algorithm = decl::Algorithm; + +// ============================================================================= +// gfx950-targeted kernel declarations +// ============================================================================= + +DECL_KERNEL_SET(gfx950_gemm_kernels, + + // fp16 128x128x32 -- bread-and-butter config, works on all CDNA + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 32) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950") + + // fp16 128x128x64 -- deeper K tile using more LDS + // LDS usage: 128*64*2 + 128*64*2 = 32768 bytes (fits 64KB, gfx950 has 160KB) + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950") + + // fp16 64x64x32 -- small-tile variant for small problems + .add(Signature().dtype("fp16").layout("rcr"), + Algorithm() + .tile(64, 64, 32) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle"), + "gfx950")); + +// ============================================================================= +// MAIN +// ============================================================================= + +int main(int argc, char* argv[]) +{ + ExampleArgs args("Example 07: gfx950 Minimal GEMM", + "Demonstrates gfx950 (CDNA4 / MI350) dispatcher"); + args.add_flag("--list", "List registered kernels"); + args.add_flag("--list-verbose", "List registered kernels with full details"); + args.add_option("--M", "1024", "Problem M dimension"); + args.add_option("--N", "1024", "Problem N dimension"); + args.add_option("--K", "1024", "Problem K dimension"); + args.add_option("--arch", "gfx950", "GPU architecture (default: gfx950)"); + + if(!args.parse(argc, argv)) + return 0; + + std::string gfx_arch = args.get("--arch", "gfx950"); + + print_header("Example 07: gfx950 (CDNA4) Minimal GEMM"); + + // ========================================================================= + // Architecture info + // ========================================================================= + std::cout << "\ngfx950 (CDNA4 / MI350) highlights:\n"; + std::cout << " - 160KB LDS (up from 64KB on gfx942)\n"; + std::cout << " - Extended FP8 warp tiles: 16x16x128, 32x32x64\n"; + std::cout << " - Packed FP4 support (pk_fp4)\n"; + std::cout << " - Same warp configs as gfx942: [1,4,1], [2,2,1], [4,1,1]\n\n"; + + // ========================================================================= + // Register kernels + // ========================================================================= + std::cout << "Registering kernels for " << gfx_arch << "...\n"; + + Registry registry; + registry.set_name("gfx950_gemm"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + if(args.has("--list") || args.has("--list-verbose")) + { + std::cout << "\n"; + print_registered_kernels(registry, std::cout, args.has("--list-verbose")); + return 0; + } + + if(registry.size() == 0) + { + std::cerr << "ERROR: No kernels registered for " << gfx_arch << "!\n"; + std::cerr << " Did you build with -DGPU_TARGETS=gfx950?\n"; + return 1; + } + + // ========================================================================= + // Create Dispatcher + // ========================================================================= + Dispatcher dispatcher(®istry); + + // ========================================================================= + // Setup Problem + // ========================================================================= + const int M = args.get_int("--M", 1024); + const int N = args.get_int("--N", 1024); + const int K = args.get_int("--K", 1024); + + std::cout << "\nProblem: " << M << " x " << N << " x " << K << "\n"; + + Problem problem(M, N, K); + + using DataType = ck_tile::fp16_t; + GpuBuffer a_dev(M * K); + GpuBuffer b_dev(K * N); + GpuBuffer c_dev(M * N); + + std::vector a_host(M * K, DataType(1.0f)); + std::vector b_host(K * N, DataType(1.0f)); + a_dev.copy_from_host(a_host.data()); + b_dev.copy_from_host(b_host.data()); + c_dev.zero(); + + // ========================================================================= + // Select and Run + // ========================================================================= + auto selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << "ERROR: No suitable kernel found for " << M << "x" << N << "x" << K << "\n"; + return 1; + } + std::cout << " Selected: " << selected->get_name() << "\n"; + + float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n"; + + // ========================================================================= + // Verify + // ========================================================================= + std::cout << "\nVerification:\n"; + std::vector c_host(M * N); + c_dev.copy_to_host(c_host.data()); + + const float expected = static_cast(K); + int errors = 0; + for(int i = 0; i < std::min(M * N, 1024); ++i) + { + if(std::abs(static_cast(c_host[i]) - expected) > 0.01f * expected + 1.0f) + ++errors; + } + + bool passed = (errors == 0); + std::cout << " Expected value: " << expected << "\n"; + std::cout << " Errors (first 1024 elements): " << errors << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + print_separator(); + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/gemm/cpp/README.md b/dispatcher/examples/gemm/cpp/README.md index 1d81a90a0e..79d60d1198 100644 --- a/dispatcher/examples/gemm/cpp/README.md +++ b/dispatcher/examples/gemm/cpp/README.md @@ -29,14 +29,14 @@ cd examples ## Examples -| Example | Description | Complexity | -|---------|-------------|------------| -| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ | -| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ | -| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ | -| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ | -| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ | -| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ | +| Example | Description | +|---------|-------------| +| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | +| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | +| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | +| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | +| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | +| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ## Example Details @@ -225,5 +225,5 @@ DECL_KERNEL_SET(my_kernels, ## Related Documentation - [Python GEMM Examples](../python/README.md) -- [Convolution Examples](../../conv/cpp/README.md) +- [C++ Headers (GEMM + Grouped Conv)](../../../include/ck_tile/dispatcher/README.md) - [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/gemm/python/01_basic_gemm.py b/dispatcher/examples/gemm/python/01_basic_gemm.py index 93a78d24d1..8c23da89e2 100644 --- a/dispatcher/examples/gemm/python/01_basic_gemm.py +++ b/dispatcher/examples/gemm/python/01_basic_gemm.py @@ -7,41 +7,37 @@ Example 01: Basic GEMM with Multiple Kernels Demonstrates: -1. Declaring multiple kernel configurations -2. Printing all registered kernels -3. Running each kernel and validating output +1. Building a Registry with multiple kernel configurations +2. Parallel JIT compilation via registry.build() +3. Running each kernel and validating output against NumPy reference 4. Comparing performance across kernels -Complexity: ★★☆☆☆ - Usage: python3 01_basic_gemm.py - python3 01_basic_gemm.py --help python3 01_basic_gemm.py --dtype bf16 python3 01_basic_gemm.py --size 2048 + python3 01_basic_gemm.py --num-kernels 4 + python3 01_basic_gemm.py --workers 4 """ import sys +import time import argparse from pathlib import Path from dataclasses import dataclass -from typing import List sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @dataclass class KernelSpec: - """Specification for a kernel configuration""" - name: str tile_m: int tile_n: int @@ -50,80 +46,37 @@ class KernelSpec: scheduler: str = "intrawave" -# Define multiple kernel configurations to test (50+ kernels) KERNEL_SPECS = [ - # Small tiles - compv3 + # Small tiles KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"), KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"), - # Small tiles - compv4 KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"), - KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"), - # Medium tiles - compv3 + # Medium tiles KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"), KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"), - KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"), - # Medium tiles - compv4 KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"), KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"), - KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"), - # Rectangular tiles - compv3 + # Rectangular tiles KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"), KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"), KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"), KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"), - # Rectangular tiles - compv4 KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"), - KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"), KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"), - KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"), - # Large tiles - compv3 + # Large tiles KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"), - KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"), KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"), - KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"), KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"), - KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"), - # Large tiles - compv4 KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"), - KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"), - KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"), - KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"), KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"), - KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"), - # Interwave scheduler variants - KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"), + # Interwave scheduler KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"), - KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"), KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"), - # More tile_k variations - compv3 - KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"), - KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"), - KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"), - # More tile_k variations - compv4 - KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"), - KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"), - KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"), - # Additional rectangular - KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"), - KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"), - KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"), - KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"), - # Additional compv4 variants - KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"), - KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"), - KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"), - KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"), ] -def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: - """Create a KernelConfig from a spec""" - # Adjust warp tiles based on tile size - if spec.tile_m <= 64: - warp_m, warp_n = 16, 16 - else: - warp_m, warp_n = 32, 32 - +def spec_to_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig: + warp_m, warp_n = (16, 16) if spec.tile_m <= 64 else (32, 32) return KernelConfig( dtype_a=dtype, dtype_b=dtype, @@ -148,180 +101,118 @@ def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfi ) -def print_kernel_table(specs: List[KernelSpec], dtype: str): - """Print a formatted table of kernel configurations""" - print("\n" + "=" * 70) - print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)") - print("=" * 70) - print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") - print(" " + "-" * 68) - - for i, spec in enumerate(specs, 1): - tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" - print( - f" {i:<3} {spec.name:<18} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}" - ) - - print(" " + "-" * 68) - print(f" Data type: {dtype}") - - def main(): - parser = argparse.ArgumentParser( - description="Basic GEMM Example with Multiple Kernels", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - python3 01_basic_gemm.py # Default FP16 with 4 kernels - python3 01_basic_gemm.py --dtype bf16 # BF16 mode - python3 01_basic_gemm.py --size 2048 # Larger problem size - python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels - """, - ) + parser = argparse.ArgumentParser(description="Basic GEMM with Multiple Kernels") + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--size", type=int, default=512, help="Problem size MxNxK") + parser.add_argument("--num-kernels", type=int, default=0, help="0 = all") parser.add_argument( - "--dtype", - default="fp16", - choices=["fp16", "bf16", "fp32"], - help="Data type (default: fp16)", - ) - parser.add_argument( - "--arch", - default="gfx942", - help="Target architecture (default: gfx942)", - ) - parser.add_argument( - "--size", - type=int, - default=512, - help="Problem size MxNxK (default: 512)", - ) - parser.add_argument( - "--num-kernels", - type=int, - default=0, - help="Number of kernels to test (0 = all)", + "--workers", type=int, default=0, help="Max parallel JIT workers (0 = auto)" ) args = parser.parse_args() - reset_for_example() - print("=" * 70) print("Example 01: Basic GEMM with Multiple Kernels") print("=" * 70) - # Select kernels to test specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS - # ========================================================================= - # Step 1: Print all kernel configurations - # ========================================================================= - print_kernel_table(specs, args.dtype) - - # ========================================================================= - # Step 2: Setup and test each kernel - # ========================================================================= - print("\n" + "=" * 70) - print(" RUNNING KERNELS") - print("=" * 70) - - np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 - M, N, K = args.size, args.size, args.size - - results = [] - - print(f"\n Problem size: {M}x{N}x{K}\n") + # Step 1: Build registry print( - f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}" + f"\n {len(specs)} kernel configurations, dtype={args.dtype}, arch={args.arch}" ) - print(" " + "-" * 78) - - for i, spec in enumerate(specs, 1): - # Create unique test data per kernel - np.random.seed(42 + i * 1000) - A = (np.random.randn(M, K) * 0.1).astype(np_dtype) - B = (np.random.randn(K, N) * 0.1).astype(np_dtype) - - # Create config and setup dispatcher - config = create_kernel_config(spec, args.dtype, args.arch) - - setup = setup_gemm_dispatcher( - config=config, - registry_name=f"kernel_{spec.name}", - verbose=False, - auto_rebuild=True, + print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}") + print(" " + "-" * 64) + for i, s in enumerate(specs, 1): + print( + f" {i:<3} {s.name:<22} {s.tile_m}x{s.tile_n}x{s.tile_k:<6} {s.pipeline:<10} {s.scheduler:<12}" ) + reg = Registry(name="basic_gemm") + for s in specs: + reg.register_kernel(spec_to_config(s, args.dtype, args.arch)) + + # Step 2: Parallel JIT build via registry.build() + workers = args.workers if args.workers > 0 else None + print( + f"\n--- Parallel JIT Build ({len(specs)} kernels{f', workers={workers}' if workers else ''}) ---" + ) + + t0 = time.perf_counter() + setups = reg.build(verbose=False, max_workers=workers) + jit_build_s = time.perf_counter() - t0 + + built = sum(1 for s in setups if s.success) + print(f" Built: {built}/{len(specs)} kernels in {jit_build_s:.1f} s") + + if built == 0: + print(" ERROR: No kernels built") + return 1 + + # Step 3: Run each kernel and validate + print(f"\n--- Running Kernels (problem {args.size}x{args.size}x{args.size}) ---") + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + M = N = K = args.size + + np.random.seed(42) + A = (np.random.randn(M, K) * 0.1).astype(np_dtype) + B = (np.random.randn(K, N) * 0.1).astype(np_dtype) + C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) + + print( + f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':<6}" + ) + print(" " + "-" * 80) + + results = [] + for i, (spec, setup) in enumerate(zip(specs, setups), 1): tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}" if not setup.success: print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}" ) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - dispatcher = setup.dispatcher - - # Check if size is supported - if not dispatcher.is_supported(M, N, K): + disp = setup.dispatcher + if not disp.is_supported(M, N, K): print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}" ) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - # Run GEMM - result = dispatcher.run(A, B, M, N, K) - - if not result.success: + res = disp.run(A, B, M, N, K) + if not res.success: print( - f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'FAIL':<6}" ) - results.append((spec.name, False, 0, 0, 0)) - cleanup_gemm() + results.append((spec.name, False, 0.0, 0.0, 0.0)) continue - # Validate against NumPy reference - C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype) - max_err = np.max(np.abs(result.output - C_ref)) - - # Check if within tolerance - passed = max_err < 1e-2 - status = "PASS" if passed else "FAIL" - + max_err = float(np.max(np.abs(res.output - C_ref))) + ok = max_err < 1e-2 + tag = "PASS" if ok else "FAIL" print( - f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}" + f" {i:<3} {spec.name:<22} {tile:<14} {res.time_ms:>10.4f} {res.tflops:>10.2f} {max_err:>10.2e} {tag:<6}" ) - results.append((spec.name, passed, result.time_ms, result.tflops, max_err)) - - cleanup_gemm() - - # ========================================================================= - # Step 3: Summary - # ========================================================================= - print("\n" + "=" * 70) - print(" SUMMARY") - print("=" * 70) + results.append((spec.name, ok, res.time_ms, res.tflops, max_err)) + # Step 4: Summary passed = sum(1 for r in results if r[1]) failed = len(results) - passed + valid = [r for r in results if r[1]] - print(f"\n Results: {passed}/{len(results)} kernels passed") - print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") - - if results: - valid_results = [r for r in results if r[1]] - if valid_results: - best = max(valid_results, key=lambda x: x[3]) - print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)") - - if failed == 0: - print("\n *** ALL KERNELS PASSED ***") - else: - print(f"\n *** {failed} KERNELS FAILED ***") - + print("\n" + "=" * 70) + print(f" Results: {passed}/{len(results)} passed") + print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}") + print(f" JIT time: {jit_build_s:.1f} s (parallel)") + if valid: + best = max(valid, key=lambda x: x[3]) + print(f" Best: {best[0]} ({best[3]:.2f} TFLOPS)") + print(f" Status: {'PASS' if failed == 0 else 'FAIL'}") print("=" * 70) return 0 if failed == 0 else 1 diff --git a/dispatcher/examples/gemm/python/02_batch_gemm.py b/dispatcher/examples/gemm/python/02_batch_gemm.py index 039aba2790..745ec1c494 100644 --- a/dispatcher/examples/gemm/python/02_batch_gemm.py +++ b/dispatcher/examples/gemm/python/02_batch_gemm.py @@ -6,9 +6,7 @@ """ Example 02: Batch GEMM -Runs multiple GEMM operations with different sizes. - -Complexity: ★★☆☆☆ +Runs multiple GEMM operations with different sizes using JIT compilation. Usage: python3 02_batch_gemm.py @@ -25,9 +23,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -55,20 +52,20 @@ Examples: help="Maximum problem size (default: 4096)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 02: Batch GEMM") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher + # Step 1: JIT build dispatcher # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -80,19 +77,22 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="batch_gemm") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Run batch of different sizes # ========================================================================= print("\nStep 2: Run Batch") - # Generate sizes up to max_size all_sizes = [ (256, 256, 256), (512, 512, 512), @@ -135,9 +135,6 @@ Examples: avg_tflops = (total_ops / 1e12) / (total_time / 1000) print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS") - # Cleanup - cleanup_gemm() - print("\n" + "=" * 60) print("Batch GEMM complete!") print("=" * 60) diff --git a/dispatcher/examples/gemm/python/03_benchmark.py b/dispatcher/examples/gemm/python/03_benchmark.py index bec1b7e2fb..508b3f8b35 100644 --- a/dispatcher/examples/gemm/python/03_benchmark.py +++ b/dispatcher/examples/gemm/python/03_benchmark.py @@ -6,9 +6,8 @@ """ Example 03: Benchmark -Performance benchmarking with compute-optimized kernel configuration. - -Complexity: ★★★☆☆ +Performance benchmarking with compute-optimized kernel configuration +using JIT compilation. Usage: python3 03_benchmark.py @@ -26,9 +25,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -63,20 +61,20 @@ Examples: "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)" ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 03: Benchmark") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher with compute-optimized config + # Step 1: JIT build dispatcher with compute-optimized config # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -90,12 +88,16 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="benchmark") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Benchmark @@ -130,11 +132,9 @@ Examples: A = np.random.randn(M, K).astype(np_dtype) * 0.1 B = np.random.randn(K, N).astype(np_dtype) * 0.1 - # Warmup for _ in range(args.warmup): dispatcher.run(A, B, M, N, K) - # Benchmark times = [] for _ in range(args.iterations): result = dispatcher.run(A, B, M, N, K) @@ -150,9 +150,6 @@ Examples: f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}" ) - # Cleanup - cleanup_gemm() - # Summary print("\n" + "=" * 60) print("Summary") diff --git a/dispatcher/examples/gemm/python/04_validation.py b/dispatcher/examples/gemm/python/04_validation.py index 2fe54c53f7..d56621c3c8 100644 --- a/dispatcher/examples/gemm/python/04_validation.py +++ b/dispatcher/examples/gemm/python/04_validation.py @@ -6,9 +6,7 @@ """ Example 04: Validation -Validates GPU GEMM against NumPy reference. - -Complexity: ★★★☆☆ +Validates GPU GEMM against NumPy reference using JIT compilation. Usage: python3 04_validation.py @@ -26,9 +24,8 @@ import numpy as np from ctypes_utils import ( KernelConfig, Validator, - setup_gemm_dispatcher, - cleanup_gemm, - reset_for_example, + Registry, + detect_gpu_arch, ) @@ -56,20 +53,20 @@ Examples: "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)" ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() - reset_for_example() - print("=" * 60) print("Example 04: Validation") print("=" * 60) # ========================================================================= - # Step 1: Setup dispatcher + # Step 1: JIT build dispatcher # ========================================================================= - print("\nStep 1: Setup Dispatcher") + print("\nStep 1: JIT Build Dispatcher") config = KernelConfig( dtype_a=args.dtype, @@ -81,12 +78,16 @@ Examples: gfx_arch=args.arch, ) - setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True) - if not setup.success: - print(f" ERROR: {setup.error}") + reg = Registry(name="validation") + reg.register_kernel(config) + + setups = reg.build(verbose=True) + if not setups or not setups[0].success: + error = setups[0].error if setups else "No kernels built" + print(f" ERROR: {error}") return 1 - dispatcher = setup.dispatcher + dispatcher = setups[0].dispatcher # ========================================================================= # Step 2: Run validation tests @@ -139,9 +140,6 @@ Examples: print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED") failed += 1 - # Cleanup - cleanup_gemm() - # Summary print("\n" + "=" * 60) total = passed + failed diff --git a/dispatcher/examples/gemm/python/05_numpy_integration.py b/dispatcher/examples/gemm/python/05_numpy_integration.py index 493ce46d22..b0af5fa700 100644 --- a/dispatcher/examples/gemm/python/05_numpy_integration.py +++ b/dispatcher/examples/gemm/python/05_numpy_integration.py @@ -8,7 +8,6 @@ Example 05: NumPy Integration Shows how to create a GPU-accelerated matmul wrapper. -Complexity: ★★☆☆☆ Usage: python3 05_numpy_integration.py @@ -29,6 +28,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -70,7 +70,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/06_json_export.py b/dispatcher/examples/gemm/python/06_json_export.py index 9e062e507b..780032ce06 100644 --- a/dispatcher/examples/gemm/python/06_json_export.py +++ b/dispatcher/examples/gemm/python/06_json_export.py @@ -8,7 +8,6 @@ Example 06: JSON Export Exports registry configuration to JSON. -Complexity: ★★☆☆☆ Usage: python3 06_json_export.py @@ -28,6 +27,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -54,7 +54,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/07_stress_test.py b/dispatcher/examples/gemm/python/07_stress_test.py index 8160030631..620e66eeaf 100644 --- a/dispatcher/examples/gemm/python/07_stress_test.py +++ b/dispatcher/examples/gemm/python/07_stress_test.py @@ -18,7 +18,6 @@ This tests: - Multiple data types (fp16, bf16) - Different schedulers (intrawave, interwave) -Complexity: ★★★★☆ Usage: python3 07_stress_test.py @@ -43,6 +42,7 @@ from ctypes_utils import ( cleanup_gemm, reset_for_example, Validator, + detect_gpu_arch, ) @@ -413,8 +413,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target architecture (default: gfx942)", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/08_heuristics.py b/dispatcher/examples/gemm/python/08_heuristics.py index e2763c0513..acbf1b3ae0 100644 --- a/dispatcher/examples/gemm/python/08_heuristics.py +++ b/dispatcher/examples/gemm/python/08_heuristics.py @@ -19,7 +19,6 @@ Heuristic strategies: - Memory-bound: Optimize memory access for bandwidth-limited cases - Latency-focused: Minimize kernel launch overhead for small problems -Complexity: ★★★★☆ Usage: python3 08_heuristics.py @@ -43,6 +42,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -561,8 +561,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target architecture (default: gfx942)", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/09_multi_registry.py b/dispatcher/examples/gemm/python/09_multi_registry.py index 97cbce3497..5d9af239d4 100644 --- a/dispatcher/examples/gemm/python/09_multi_registry.py +++ b/dispatcher/examples/gemm/python/09_multi_registry.py @@ -8,7 +8,6 @@ Example 09: Multiple Registries Demonstrates multiple registries for different optimization targets. -Complexity: ★★★★★ Usage: python3 09_multi_registry.py @@ -30,6 +29,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -50,7 +50,9 @@ Examples: help="Data type (default: fp16)", ) parser.add_argument( - "--arch", default="gfx942", help="Target architecture (default: gfx942)" + "--arch", + default=detect_gpu_arch(), + help="Target architecture (auto-detected from rocminfo)", ) args = parser.parse_args() diff --git a/dispatcher/examples/gemm/python/10_advanced_benchmark.py b/dispatcher/examples/gemm/python/10_advanced_benchmark.py index e16e4e271f..b1462478d0 100644 --- a/dispatcher/examples/gemm/python/10_advanced_benchmark.py +++ b/dispatcher/examples/gemm/python/10_advanced_benchmark.py @@ -33,6 +33,7 @@ from ctypes_utils import ( setup_gemm_dispatcher, cleanup_gemm, reset_for_example, + detect_gpu_arch, ) @@ -69,7 +70,11 @@ def parse_args(): # Kernel configuration parser.add_argument("--dtype", default="fp16", help="Data type") parser.add_argument("--pipeline", default="compv4", help="Pipeline type") - parser.add_argument("--arch", default="gfx942", help="GPU architecture") + parser.add_argument( + "--arch", + default=detect_gpu_arch(), + help="GPU architecture (auto-detected from rocminfo)", + ) return parser.parse_args() diff --git a/dispatcher/examples/gemm/python/11_json_import.py b/dispatcher/examples/gemm/python/11_json_import.py index 06743af406..d19395e553 100644 --- a/dispatcher/examples/gemm/python/11_json_import.py +++ b/dispatcher/examples/gemm/python/11_json_import.py @@ -15,7 +15,6 @@ Key Features: - Use arch_filter validation on loaded configs - Export to C++ DECL_KERNEL_SET format -Complexity: ★★★☆☆ Usage: python3 11_json_import.py @@ -45,6 +44,7 @@ from ctypes_utils import ( # noqa: E402 cleanup_gemm, reset_for_example, validate_kernel_config, + detect_gpu_arch, ) # Sample JSON configuration (embedded for demonstration) @@ -141,8 +141,8 @@ Examples: ) parser.add_argument( "--arch", - default="gfx942", - help="Target GPU architecture (default: gfx942)", + default=detect_gpu_arch(), + help="Target GPU architecture (auto-detected from rocminfo, override with --arch gfxNNN)", ) args = parser.parse_args() @@ -236,13 +236,13 @@ Examples: else: invalid_count += 1 if invalid_count <= 3: # Show first 3 invalid - print(f"\n ✗ Invalid: {config.kernel_name()}") + print(f"\n FAIL Invalid: {config.kernel_name()}") for error in result.errors: print(f" Error: {error}") print("\n Validation Summary:") - print(f" ✓ Valid: {valid_count}") - print(f" ✗ Invalid: {invalid_count}") + print(f" OK Valid: {valid_count}") + print(f" FAIL Invalid: {invalid_count}") print(f" Total: {len(configs)}") # ========================================================================= @@ -275,12 +275,12 @@ Examples: disp_config, registry_name="json_import", verbose=False ) if setup.success: - print(" ✓ Dispatcher setup successful") + print(" OK Dispatcher setup successful") print( f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}" ) else: - print(f" ⚠ Dispatcher setup: {setup.error}") + print(f" WARNING Dispatcher setup: {setup.error}") print(" (This is expected if kernels aren't generated)") # ========================================================================= diff --git a/dispatcher/examples/gemm/python/README.md b/dispatcher/examples/gemm/python/README.md index 0a83f3533f..07757b951b 100644 --- a/dispatcher/examples/gemm/python/README.md +++ b/dispatcher/examples/gemm/python/README.md @@ -295,5 +295,5 @@ Compilation time scales roughly linearly with kernel count. ## Related Documentation - [C++ GEMM Examples](../cpp/README.md) -- [Python Conv Examples](../../conv/python/README.md) +- [Python Utilities](../../../python/README.md) - [Main Dispatcher README](../../../README.md) diff --git a/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp b/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp new file mode 100644 index 0000000000..b503129c57 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp @@ -0,0 +1,203 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 01: Basic Grouped Convolution +// +// Demonstrates three declaration patterns (mirrors GEMM 01): +// 1. AUTOFILL - tile + pipeline only, wave/warp auto-filled +// 2. AUTOCORRECT - invalid wave(1,1,1) corrected to valid config +// 3. FULL - all parameters explicit (matches validated gfx942 config) +// +// Then runs the forward convolution on GPU and verifies output. +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_01_basic + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Three declaration patterns -- codegen auto-fills/auto-corrects as needed +DECL_GROUPED_CONV_KERNEL_SET( + basic_conv_kernels, + // Pattern 1: AUTOFILL - only tile + pipeline, rest auto-filled + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").scheduler("intrawave"), + "gfx950") + // Pattern 2: AUTOCORRECT - wave(1,1,1) invalid, corrected to (1,4,1) + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 64, 64) + .wave(1, 1, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8), + "gfx950") + // Pattern 3: FULL - all parameters explicit (validated config) + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 01: Basic Grouped Convolution", + "Declaration patterns + GPU execution"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-g", "1", "Groups"); + args.add_option("-c", "64", "Input channels C"); + args.add_option("-k", "128", "Output channels K"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 01: Basic Grouped Convolution"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1); + int G = args.get_int("-g", 1); + int C = args.get_int("-c", 64); + int K = args.get_int("-k", 128); + int HW = args.get_int("--size", 14); + int Y = 3, X = 3; + + // Step 1: Show declared kernel sets + std::cout << "\nStep 1: Declared Kernel Sets\n"; + GroupedConvKernelSetRegistry::instance().print(); + + // Step 2: Register kernels + std::cout << "\nStep 2: Register Kernels\n"; + GroupedConvRegistry registry; + registry.set_name("basic_conv"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + // Step 3: Create dispatcher + std::cout << "\nStep 3: Create Dispatcher\n"; + GroupedConvDispatcher dispatcher(®istry); + + // Step 4: Build problem using CK Tile ConvParam + std::cout << "\nStep 4: Problem\n"; + auto problem = create_grouped_conv2d_problem(N, C, K, HW, HW, Y, X, 1, 1); + problem.op = GroupedConvOp::Forward; + print_grouped_conv_problem(problem); + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(HW), static_cast(HW)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input_host(in_desc); + ck_tile::HostTensor weight_host(wei_desc); + ck_tile::HostTensor output_host(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input_host); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight_host); + + ck_tile::DeviceMem input_dev(input_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output_host.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input_host.data()); + weight_dev.ToDevice(weight_host.data()); + + // Step 5: Select and run + std::cout << "\nStep 5: Select and Run\n"; + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No kernel found for problem!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + + double tflops = calculate_conv_tflops(problem, time_ms); + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 6: Verify + std::cout << "\nStep 6: Verify\n"; + output_dev.FromDevice(output_host.data()); + + size_t total = output_host.get_element_space_size(); + size_t nonzero = 0; + double checksum = 0.0; + for(size_t i = 0; i < total; ++i) + { + float v = static_cast(output_host.data()[i]); + if(v != 0.0f) + ++nonzero; + checksum += v; + } + + bool passed = nonzero > 0; + std::cout << " Output elements: " << total << "\n"; + std::cout << " Non-zero: " << nonzero << "/" << total + << (nonzero > 0 ? " (kernel produced output)" : " WARNING: all zeros!") << "\n"; + std::cout << " Checksum: " << std::fixed << std::setprecision(2) << checksum << "\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + + utils::print_separator(); + std::cout << "DECLARATION PATTERNS:\n"; + std::cout << " 1. AUTOFILL: tile + pipeline only, wave/warp auto-filled\n"; + std::cout << " 2. AUTOCORRECT: invalid wave(1,1,1) corrected\n"; + std::cout << " 3. FULL: all parameters explicit\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp b/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp new file mode 100644 index 0000000000..a2f2b9d560 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp @@ -0,0 +1,216 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 02: All Convolution Directions +// +// Forward, backward-data, and backward-weight for 2D convolution, +// each executed on GPU with non-zero verification. +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_02_all_dirs + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + conv_fwd_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950")); + +DECL_GROUPED_CONV_KERNEL_SET( + conv_bwdd_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +DECL_GROUPED_CONV_KERNEL_SET( + conv_bwdw_2d, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 02: All Convolution Directions", + "Forward/BwdData/BwdWeight with GPU execution and verification"); + args.add_option("--arch", "gfx950", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 02: All Convolution Directions"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + + GroupedConvRegistry registry; + registry.set_name("all_directions"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + GroupedConvDispatcher dispatcher(®istry); + + const int N = 1, G = 1, C = 64, K = 128, Hi = 14, Wi = 14, Y = 3, X = 3; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + + ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input.data()); + weight_dev.ToDevice(weight.data()); + + std::cout << "\n " << std::left << std::setw(12) << "Direction" << std::right << std::setw(10) + << "Time(ms)" << std::setw(10) << "TFLOPS" << std::setw(14) << "NonZero" + << std::setw(10) << "Status" << "\n"; + std::cout << " " << std::string(56, '-') << "\n"; + + bool all_pass = true; + + auto print_result = + [](const char* label, float time_ms, double tflops, size_t nz, size_t total, bool ok) { + std::cout << " " << std::left << std::setw(12) << label << std::right << std::fixed + << std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2) + << std::setw(10) << tflops << std::setw(14) + << (std::to_string(nz) + "/" + std::to_string(total)) << std::setw(10) + << (ok ? "OK" : "FAIL") << "\n"; + }; + + // Forward: run(X, W, Y) + { + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::Forward); + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + output_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t i = 0; i < output.get_element_space_size(); ++i) + if(static_cast(output.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; + print_result("forward", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + output.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + // Backward Data: run(dY, W, dX) + { + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData); + ck_tile::HostTensor dx_host(in_desc); + ck_tile::DeviceMem dx_dev(dx_host.get_element_space_size_in_bytes()); + float time_ms = dispatcher.run(output_dev.GetDeviceBuffer(), // dY (from forward pass) + weight_dev.GetDeviceBuffer(), // W + dx_dev.GetDeviceBuffer(), // dX (output) + problem, + nullptr); + dx_dev.FromDevice(dx_host.data()); + size_t nz = 0; + for(size_t i = 0; i < dx_host.get_element_space_size(); ++i) + if(static_cast(dx_host.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; + print_result("bwd_data", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + dx_host.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + // Backward Weight: run(X, dY, dW) + { + auto problem = create_grouped_conv2d_problem( + N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight); + ck_tile::HostTensor dw_host(wei_desc); + ck_tile::DeviceMem dw_dev(dw_host.get_element_space_size_in_bytes()); + float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), // X + output_dev.GetDeviceBuffer(), // dY + dw_dev.GetDeviceBuffer(), // dW (output) + problem, + nullptr); + dw_dev.FromDevice(dw_host.data()); + size_t nz = 0; + for(size_t i = 0; i < dw_host.get_element_space_size(); ++i) + if(static_cast(dw_host.data()[i]) != 0.0f) + ++nz; + bool ok = nz > 0; + print_result("bwd_weight", + time_ms, + calculate_conv_tflops(problem, time_ms), + nz, + dw_host.get_element_space_size(), + ok); + if(!ok) + all_pass = false; + } + + utils::print_separator(); + std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return all_pass ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp b/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp new file mode 100644 index 0000000000..12bd87d1a4 --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp @@ -0,0 +1,263 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 03: Benchmark and CPU-Reference Validation +// +// Runs a 2D grouped conv forward kernel on the GPU via dispatcher.run() +// and compares against the CK Tile host reference implementation. +// Exposes warmup/repeat/log_level as CLI args (matches example 20 pattern). +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_03_bench_val + +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; +using AccDataType = float; + +DECL_GROUPED_CONV_KERNEL_SET( + bench_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950") + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 03: Benchmark & Validation", + "GPU execution with CPU reference validation"); + args.add_option("-n", "1", "Batch size N"); + args.add_option("-g", "1", "Groups G"); + args.add_option("-c", "64", "Input channels C"); + args.add_option("-k", "128", "Output channels K"); + args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("--warmup", "3", "Warmup iterations"); + args.add_option("--repeat", "10", "Benchmark iterations"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_flag("--no-verify", "Skip CPU validation"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 03: Grouped Conv Benchmark & Validation"); + + int N = args.get_int("-n", 1); + int G = args.get_int("-g", 1); + int C = args.get_int("-c", 64); + int K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14); + int Wi = Hi; + int Y = 3, X = 3; + int warmup = args.get_int("--warmup", 3); + int repeat = args.get_int("--repeat", 10); + bool verify = !args.has("--no-verify"); + std::string gfx_arch = args.get("--arch", "gfx950"); + + std::cout << "\nProblem: N=" << N << " G=" << G << " C=" << C << " K=" << K << " Hi=" << Hi + << " Wi=" << Wi << " Y=" << Y << " X=" << X << "\n"; + std::cout << "Benchmark: warmup=" << warmup << " repeat=" << repeat << "\n"; + + // Step 1: Setup tensors using CK Tile descriptors + std::cout << "\nStep 1: Setup tensors\n"; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output_gpu(out_desc); + ck_tile::HostTensor output_cpu(out_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + output_cpu.SetZero(); + + std::cout << " Input: " << input.get_element_space_size() << " elements\n"; + std::cout << " Weight: " << weight.get_element_space_size() << " elements\n"; + std::cout << " Output: " << output_gpu.get_element_space_size() << " elements\n"; + + // Step 2: CPU reference + if(verify) + { + std::cout << "\nStep 2: CPU Reference\n"; + + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_fwd<2, InDataType, WeiDataType, OutDataType>( + input, weight, output_cpu, strides_v, dilations_v, left_pads_v, right_pads_v); + + std::cout << " CPU ref[0..7]: "; + for(int i = 0; i < std::min(8, static_cast(output_cpu.get_element_space_size())); ++i) + std::cout << std::fixed << std::setprecision(4) + << static_cast(output_cpu.data()[i]) << " "; + std::cout << "\n"; + } + + // Step 3: GPU execution via dispatcher + std::cout << "\nStep 3: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + registry.set_name("bench_val"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + GroupedConvDispatcher dispatcher(®istry); + + auto problem = create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1); + problem.op = GroupedConvOp::Forward; + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem output_dev(output_gpu.get_element_space_size_in_bytes()); + + input_dev.ToDevice(input.data()); + weight_dev.ToDevice(weight.data()); + + float elapsed_ms = dispatcher.run(input_dev.GetDeviceBuffer(), + weight_dev.GetDeviceBuffer(), + output_dev.GetDeviceBuffer(), + problem, + nullptr); + + output_dev.FromDevice(output_gpu.data()); + + size_t total = output_gpu.get_element_space_size(); + std::cout << " GPU out[0..7]: "; + for(int i = 0; i < std::min(8, static_cast(total)); ++i) + std::cout << std::fixed << std::setprecision(4) << static_cast(output_gpu.data()[i]) + << " "; + std::cout << "\n"; + + size_t nonzero_gpu = 0; + double gpu_sum = 0.0; + for(size_t i = 0; i < total; ++i) + { + float v = static_cast(output_gpu.data()[i]); + if(v != 0.0f) + ++nonzero_gpu; + gpu_sum += v; + } + std::cout << " GPU checksum: " << std::fixed << std::setprecision(6) << gpu_sum << "\n"; + std::cout << " GPU non-zero: " << nonzero_gpu << "/" << total + << (nonzero_gpu > 0 ? " (kernel produced output)" : " WARNING: all zeros!") << "\n"; + + int Ho = static_cast(problem.Ho()); + int Wo = static_cast(problem.Wo()); + double flops = 2.0 * G * N * K * C * Y * X * Ho * Wo; + double tflops = flops / (elapsed_ms * 1e9); + + std::cout << " Time: " << std::fixed << std::setprecision(4) << elapsed_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Step 4: Validation + bool passed = true; + if(verify) + { + std::cout << "\nStep 4: Validation (GPU vs CPU)\n"; + + constexpr float rtol = 1e-2f; + constexpr float atol = 1e-2f; + + float max_diff = 0.0f; + float max_rel = 0.0f; + size_t max_diff_idx = 0; + size_t num_elements = output_gpu.get_element_space_size(); + size_t mismatches = 0; + + for(size_t i = 0; i < num_elements; ++i) + { + float gpu_val = static_cast(output_gpu.data()[i]); + float cpu_val = static_cast(output_cpu.data()[i]); + float diff = std::abs(gpu_val - cpu_val); + float tol = atol + rtol * std::abs(cpu_val); + float rel = diff / (std::abs(cpu_val) + 1e-6f); + if(diff > max_diff) + { + max_diff = diff; + max_diff_idx = i; + } + max_rel = std::max(max_rel, rel); + if(diff > tol) + ++mismatches; + } + + passed = (mismatches == 0); + + std::cout << " Side-by-side at worst element [" << max_diff_idx << "]:\n"; + std::cout << " GPU: " << std::fixed << std::setprecision(6) + << static_cast(output_gpu.data()[max_diff_idx]) + << " CPU: " << static_cast(output_cpu.data()[max_diff_idx]) + << " diff: " << std::scientific << max_diff << "\n"; + std::cout << " Elements: " << num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "/" << num_elements << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_diff << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + std::cout << " Status: " << (passed ? "PASSED" : "FAILED") << "\n"; + } + + utils::print_separator(); + std::cout << "BENCHMARK & VALIDATION:\n"; + std::cout << " GPU kernel: " << (selected ? selected->name() : "none") << "\n"; + std::cout << " Performance: " << std::fixed << std::setprecision(2) << tflops + << " TFLOPS\n"; + std::cout << " CPU reference: reference_grouped_conv_fwd<2>()\n"; + std::cout << " Validation: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp b/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp new file mode 100644 index 0000000000..0e5a6d33be --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp @@ -0,0 +1,154 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 04: Heuristic Selection + JSON Export +// +// Demonstrates runtime kernel selection with heuristic ranking, +// GPU execution, and JSON registry export. +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_04_registry_json + +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Two tile configs for heuristic selection +DECL_GROUPED_CONV_KERNEL_SET( + heuristic_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8), + "gfx950") + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8), + "gfx950")); + +std::vector conv_heuristic(const GroupedConvProblem& problem) +{ + int64_t spatial = problem.Ho() * problem.Wo(); + if(spatial > 400) + return {"128x128", "64x64"}; + return {"64x64", "128x128"}; +} + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 04: Heuristic + JSON", + "Runtime kernel selection and JSON export"); + args.add_option("--arch", "gfx950", "GPU architecture"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 04: Heuristic Selection + JSON Export"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + + // Step 1: Register + std::cout << "\nStep 1: Register Kernels" << std::endl; + GroupedConvRegistry registry; + registry.set_name("heuristic_conv"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)" << std::endl; + + // Step 2: Heuristic dispatcher + std::cout << "\nStep 2: Heuristic Dispatcher" << std::endl; + GroupedConvDispatcher dispatcher(®istry); + dispatcher.set_strategy(GroupedConvDispatcher::SelectionStrategy::Heuristic); + dispatcher.set_heuristic(conv_heuristic); + + // Step 3: Select kernels (no GPU yet) + std::cout << "\nStep 3: Kernel Selection" << std::endl; + + auto problem = create_grouped_conv2d_problem(1, 64, 128, 14, 14, 3, 3, 1, 1); + + auto* selected = dispatcher.select_kernel(problem); + std::cout << " Selected: " << (selected ? selected->name() : "none") << std::endl; + + // Step 4: GPU execution + std::cout << "\nStep 4: GPU Execution" << std::endl; + + ck_tile::conv::ConvParam cp{ + 2, + static_cast(1), + static_cast(1), + static_cast(128), + static_cast(64), + {static_cast(3), static_cast(3)}, + {static_cast(14), static_cast(14)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + std::cout << " Creating tensors..." << std::endl; + auto in_d = ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(cp); + auto wei_d = ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(cp); + auto out_d = ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(cp); + + ck_tile::HostTensor input(in_d); + ck_tile::HostTensor weight(wei_d); + ck_tile::HostTensor output(out_d); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + + std::cout << " Allocating device memory..." << std::endl; + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes()); + in_dev.ToDevice(input.data()); + wei_dev.ToDevice(weight.data()); + + std::cout << " Launching kernel..." << std::endl; + float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer(), + problem, + nullptr); + + std::cout << " Reading back..." << std::endl; + out_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t i = 0; i < output.get_element_space_size(); ++i) + if(static_cast(output.data()[i]) != 0.0f) + ++nz; + + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms" + << std::endl; + std::cout << " TFLOPS: " << std::setprecision(2) << calculate_conv_tflops(problem, time_ms) + << std::endl; + std::cout << " NonZero: " << nz << "/" << output.get_element_space_size() << std::endl; + + // Step 5: JSON export + std::cout << "\nStep 5: JSON Export" << std::endl; + std::string json = registry.export_json(false); + std::cout << " JSON size: " << json.size() << " bytes" << std::endl; + + bool passed = nz > 0; + utils::print_separator(); + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp b/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp new file mode 100644 index 0000000000..35595bb14c --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp @@ -0,0 +1,183 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 05: Backward Data with CPU Reference Validation +// +// Computes dX = ConvBwdData(dY, W) on GPU via dispatcher.run() +// and validates against ck_tile::reference_grouped_conv_bwd_data. +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_05_bwd_data + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + bwd_data_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .scheduler("intrawave") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 05: Backward Data Validation", + "dX = ConvBwdData(dY, W) with CPU reference"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-c", "64", "Input channels"); + args.add_option("-k", "128", "Output channels"); + args.add_option("--size", "14", "Spatial size (H=W)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 05: Backward Data with CPU Validation"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1), G = 1; + int C = args.get_int("-c", 64), K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3; + + // Setup + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + // dY (gradient from next layer) and W (weight) are inputs; dX is output + ck_tile::HostTensor dy(out_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor dx_gpu(in_desc); + ck_tile::HostTensor dx_cpu(in_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(dy); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + dx_cpu.SetZero(); + + // CPU reference + std::cout << "\nStep 1: CPU Reference (bwd_data)\n"; + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_bwd_data<2, InDataType, WeiDataType, OutDataType>( + dx_cpu, weight, dy, strides_v, dilations_v, left_pads_v, right_pads_v); + std::cout << " CPU complete\n"; + + // GPU execution via dispatcher + std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + registry.set_name("bwd_data"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + GroupedConvDispatcher dispatcher(®istry); + + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData); + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No bwd_data kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes()); + ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dx_dev(dx_gpu.get_element_space_size_in_bytes()); + + dy_dev.ToDevice(dy.data()); + wei_dev.ToDevice(weight.data()); + + // dispatcher.run(dY, W, dX, problem) for bwd_data + float time_ms = dispatcher.run(dy_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + dx_dev.GetDeviceBuffer(), + problem, + nullptr); + + dx_dev.FromDevice(dx_gpu.data()); + + double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validation + std::cout << "\nStep 3: Validation (GPU vs CPU)\n"; + + size_t num_elements = dx_gpu.get_element_space_size(); + float max_abs = 0, max_rel = 0; + size_t mismatches = 0; + constexpr float rtol = 5e-2f, atol = 5e-2f; + + for(size_t i = 0; i < num_elements; ++i) + { + float gv = static_cast(dx_gpu.data()[i]); + float cv = static_cast(dx_cpu.data()[i]); + float d = std::abs(gv - cv); + float r = d / (std::abs(cv) + 1e-6f); + max_abs = std::max(max_abs, d); + max_rel = std::max(max_rel, r); + if(d > atol + rtol * std::abs(cv)) + ++mismatches; + } + + bool passed = (mismatches == 0); + std::cout << " Elements: " << num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_abs << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + + utils::print_separator(); + std::cout << " dX = ConvBwdData(dY, W)\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp b/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp new file mode 100644 index 0000000000..41cb75aecf --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp @@ -0,0 +1,188 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 06: Backward Weight with CPU Reference Validation +// +// Computes dW = ConvBwdWeight(X, dY) on GPU via dispatcher.run() +// and validates against ck_tile::reference_grouped_conv_bwd_weight. +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_06_bwd_weight + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +DECL_GROUPED_CONV_KERNEL_SET( + bwd_weight_kernels, + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .pipeline("compv3") + .scheduler("intrawave") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 06: Backward Weight Validation", + "dW = ConvBwdWeight(X, dY) with CPU reference"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("-n", "1", "Batch size"); + args.add_option("-c", "64", "Input channels"); + args.add_option("-k", "128", "Output channels"); + args.add_option("--size", "14", "Spatial size (H=W)"); + args.add_option("--split-k", "1", "Split-K factor for bwd_weight (k_batch)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 06: Backward Weight with CPU Validation"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int N = args.get_int("-n", 1), G = 1; + int C = args.get_int("-c", 64), K = args.get_int("-k", 128); + int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3; + + // Setup + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(G), + static_cast(N), + static_cast(K), + static_cast(C), + {static_cast(Y), static_cast(X)}, + {static_cast(Hi), static_cast(Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + // X (input) and dY (gradient) are inputs; dW is output + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor dy(out_desc); + ck_tile::HostTensor dw_gpu(wei_desc); + ck_tile::HostTensor dw_cpu(wei_desc); + + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(dy); + dw_cpu.SetZero(); + + // CPU reference + std::cout << "\nStep 1: CPU Reference (bwd_weight)\n"; + std::vector strides_v = {1, 1}; + std::vector dilations_v = {1, 1}; + std::vector left_pads_v = {1, 1}; + std::vector right_pads_v = {1, 1}; + + ck_tile::reference_grouped_conv_bwd_weight<2, InDataType, WeiDataType, OutDataType>( + input, dw_cpu, dy, strides_v, dilations_v, left_pads_v, right_pads_v); + std::cout << " CPU complete\n"; + + // GPU execution + std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n"; + + GroupedConvRegistry registry; + registry.set_name("bwd_weight"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + + GroupedConvDispatcher dispatcher(®istry); + + auto problem = + create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight); + problem.split_k = args.get_int("--split-k", 1); + + auto* selected = dispatcher.select_kernel(problem); + if(!selected) + { + std::cerr << " ERROR: No bwd_weight kernel found!\n"; + return 1; + } + std::cout << " Selected: " << selected->name() << "\n"; + + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes()); + ck_tile::DeviceMem dw_dev(dw_gpu.get_element_space_size_in_bytes()); + + in_dev.ToDevice(input.data()); + dy_dev.ToDevice(dy.data()); + if(problem.split_k > 1) + dw_dev.SetZero(); + + // dispatcher.run(X, dY, dW, problem) for bwd_weight + float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + dy_dev.GetDeviceBuffer(), + dw_dev.GetDeviceBuffer(), + problem, + nullptr); + + dw_dev.FromDevice(dw_gpu.data()); + + double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0; + std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n"; + std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n"; + + // Validation + std::cout << "\nStep 3: Validation (GPU vs CPU)\n"; + + size_t num_elements = dw_gpu.get_element_space_size(); + float max_abs = 0, max_rel = 0; + size_t mismatches = 0; + constexpr float rtol = 5e-2f, atol = 5e-2f; + + for(size_t i = 0; i < num_elements; ++i) + { + float gv = static_cast(dw_gpu.data()[i]); + float cv = static_cast(dw_cpu.data()[i]); + float d = std::abs(gv - cv); + float r = d / (std::abs(cv) + 1e-6f); + max_abs = std::max(max_abs, d); + max_rel = std::max(max_rel, r); + if(d > atol + rtol * std::abs(cv)) + ++mismatches; + } + + bool passed = (mismatches == 0); + std::cout << " Elements: " << num_elements << "\n"; + std::cout << " Mismatches: " << mismatches << "\n"; + std::cout << " Max abs diff: " << std::scientific << max_abs << "\n"; + std::cout << " Max rel diff: " << std::scientific << max_rel << "\n"; + + utils::print_separator(); + std::cout << " dW = ConvBwdWeight(X, dY)\n"; + std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return passed ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp b/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp new file mode 100644 index 0000000000..5c95f2c45a --- /dev/null +++ b/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp @@ -0,0 +1,226 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Example 07: Multi-Tile Benchmark +// +// Benchmarks multiple tile configurations across ResNet-like problem sizes. +// Exposes warmup, repeat, and init method as CLI args (matching CK Tile +// example 20 patterns). +// +// Build: cd dispatcher/build && cmake .. && make grouped_conv_07_benchmark + +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_utils; +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +using InDataType = ck_tile::half_t; +using WeiDataType = ck_tile::half_t; +using OutDataType = ck_tile::half_t; + +// Multiple tile configurations for benchmarking +DECL_GROUPED_CONV_KERNEL_SET( + benchmark_tiles, + // Small tile - compv3 + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 64, 64) + .wave(1, 4, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950") + // Medium tile - compv3 + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 128, 128) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950") + // Large tile - compv4 with double smem buffer + .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, 256, 256) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave") + .epilogue("cshuffle") + .vector_sizes(4, 8, 8) + .block_per_cu(1), + "gfx950")); + +int main(int argc, char* argv[]) +{ + utils::ExampleArgs args("Example 07: Multi-Tile Benchmark", + "Multiple tiles across ResNet-like problem sizes"); + args.add_option("--arch", "gfx950", "GPU architecture"); + args.add_option("--warmup", "5", "Warmup iterations (passed to stream_config)"); + args.add_option("--repeat", "20", "Benchmark iterations (passed to stream_config)"); + args.add_option("--init", "0", "Init method: 0=random, 1=linear, 2=constant(1)"); + + if(!args.parse(argc, argv)) + return 0; + + utils::print_header("Example 07: Multi-Tile Benchmark"); + + std::string gfx_arch = args.get("--arch", "gfx950"); + int warmup = args.get_int("--warmup", 5); + int repeat = args.get_int("--repeat", 20); + int init_method = args.get_int("--init", 0); + + std::cout << "\n Config: warmup=" << warmup << " repeat=" << repeat << " init=" << init_method + << "\n"; + + GroupedConvRegistry registry; + registry.set_name("benchmark"); + REGISTER_GENERATED_KERNELS(registry, gfx_arch); + std::cout << " Registered " << registry.size() << " kernel(s)\n"; + + GroupedConvDispatcher dispatcher(®istry); + + // ResNet-like problem sizes + struct BenchProblem + { + const char* label; + int N, C, K, Hi, Wi, Y, X; + }; + + BenchProblem problems[] = { + {"ResNet-stage2", 1, 64, 64, 56, 56, 3, 3}, + {"ResNet-stage3", 1, 128, 128, 28, 28, 3, 3}, + {"ResNet-stage4", 1, 256, 256, 14, 14, 3, 3}, + {"ResNet-stage5", 1, 512, 512, 7, 7, 3, 3}, + {"Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1}, + {"Batch-8", 8, 64, 128, 56, 56, 3, 3}, + }; + + std::cout << "\n " << std::left << std::setw(16) << "Problem" << std::right << std::setw(5) + << "N" << std::setw(5) << "C" << std::setw(5) << "K" << std::setw(5) << "H" + << std::setw(5) << "W" << std::setw(4) << "F" << std::setw(10) << "Time(ms)" + << std::setw(10) << "TFLOPS" << std::setw(10) << "Status" << "\n"; + std::cout << " " << std::string(74, '-') << "\n"; + + bool all_pass = true; + for(const auto& bp : problems) + { + auto problem = + create_grouped_conv2d_problem(bp.N, bp.C, bp.K, bp.Hi, bp.Wi, bp.Y, bp.X, 1, 1); + problem.op = GroupedConvOp::Forward; + + ck_tile::conv::ConvParam conv_param{ + 2, + static_cast(1), + static_cast(bp.N), + static_cast(bp.K), + static_cast(bp.C), + {static_cast(bp.Y), static_cast(bp.X)}, + {static_cast(bp.Hi), static_cast(bp.Wi)}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}}; + + using InLayout = ck_tile::tensor_layout::convolution::NHWGC; + using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC; + using OutLayout = ck_tile::tensor_layout::convolution::NHWGK; + + auto in_desc = + ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + auto wei_desc = + ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed( + conv_param); + auto out_desc = + ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed( + conv_param); + + ck_tile::HostTensor input(in_desc); + ck_tile::HostTensor weight(wei_desc); + ck_tile::HostTensor output(out_desc); + + switch(init_method) + { + case 1: + ck_tile::FillMonotonicSeq{0.0f, 0.001f}(input); + ck_tile::FillMonotonicSeq{0.0f, 0.001f}(weight); + break; + case 2: + ck_tile::FillConstant{1.0f}(input); + ck_tile::FillConstant{1.0f}(weight); + break; + default: + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(input); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(weight); + break; + } + ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes()); + ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes()); + ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes()); + + in_dev.ToDevice(input.data()); + wei_dev.ToDevice(weight.data()); + + float time_ms = 0; + bool ok = false; + try + { + time_ms = dispatcher.run(in_dev.GetDeviceBuffer(), + wei_dev.GetDeviceBuffer(), + out_dev.GetDeviceBuffer(), + problem, + nullptr); + + out_dev.FromDevice(output.data()); + size_t nz = 0; + for(size_t j = 0; j < output.get_element_space_size(); ++j) + if(static_cast(output.data()[j]) != 0.0f) + ++nz; + ok = nz > 0; + } + catch(const std::exception&) + { + ok = false; + } + + double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0; + + std::string filter_str = std::to_string(bp.Y) + "x" + std::to_string(bp.X); + std::cout << " " << std::left << std::setw(16) << bp.label << std::right << std::setw(5) + << bp.N << std::setw(5) << bp.C << std::setw(5) << bp.K << std::setw(5) << bp.Hi + << std::setw(5) << bp.Wi << std::setw(4) << filter_str << std::fixed + << std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2) + << std::setw(10) << tflops << std::setw(10) << (ok ? "OK" : "FAIL") << "\n"; + if(!ok) + all_pass = false; + } + + utils::print_separator(); + std::cout << " Warmup: " << warmup << ", Repeat: " << repeat << ", Init: " << init_method + << "\n"; + std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n"; + utils::print_separator(); + + return all_pass ? 0 : 1; +} diff --git a/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py b/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py new file mode 100644 index 0000000000..46f57b3879 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 01: Basic Grouped Convolution + +Demonstrates: +1. Three kernel configuration patterns (minimal, explicit, full ConvConfigBase) +2. Adding kernels to a registry +3. Validation and auto-correction +4. JIT compilation via registry.build() +5. GPU execution with CPU reference verification + +Usage: + python3 01_basic_grouped_conv.py + python3 01_basic_grouped_conv.py --variant bwd_data + python3 01_basic_grouped_conv.py --arch gfx942 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + detect_gpu_arch, +) + + +def cpu_conv2d_fwd(inp, wei, prob): + """Naive CPU reference: 2D forward, NHWGC layout.""" + N, Hi, Wi, G, Cpg = inp.shape + _, Kpg, Y, X, _ = wei.shape + Ho, Wo = prob.Ho, prob.Wo + out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32) + for n in range(N): + for g in range(G): + for ho in range(Ho): + for wo in range(Wo): + for k in range(Kpg): + s = 0.0 + for y in range(Y): + for x in range(X): + hi = ( + ho * prob.stride_h + - prob.pad_h + + y * prob.dilation_h + ) + wi = ( + wo * prob.stride_w + - prob.pad_w + + x * prob.dilation_w + ) + if 0 <= hi < Hi and 0 <= wi < Wi: + for c in range(Cpg): + s += float(inp[n, hi, wi, g, c]) * float( + wei[g, k, y, x, c] + ) + out[n, ho, wo, g, k] = s + return out + + +def main(): + parser = argparse.ArgumentParser(description="Basic Grouped Conv Example") + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument( + "--variant", default="forward", choices=["forward", "bwd_data", "bwd_weight"] + ) + parser.add_argument("--ndim", type=int, default=2, choices=[2, 3]) + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 01: Basic Grouped Convolution") + print("=" * 70) + + # ========================================================================= + # Step 1: Three kernel configuration patterns + # ========================================================================= + print("\n--- Step 1: Kernel Configuration Patterns ---") + + # Pattern 1: MINIMAL -- only variant/dtype/arch, everything else auto-filled + config_minimal = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + ) + print("\n Pattern 1: MINIMAL (defaults auto-filled)") + config_minimal.print_config(indent=" ") + + # Pattern 2: EXPLICIT tile/wave/warp -- user controls tiling strategy + config_explicit = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + ) + print("\n Pattern 2: EXPLICIT tile/wave/warp") + config_explicit.print_config(indent=" ") + + # Pattern 3: FULL ConvConfigBase -- every parameter specified + config_full = GroupedConvKernelConfig( + variant=args.variant, + ndim_spatial=args.ndim, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + print("\n Pattern 3: FULL (all ConvConfigBase fields)") + config_full.print_config(indent=" ") + + # ========================================================================= + # Step 2: Build a registry with multiple configs + # ========================================================================= + print("\n--- Step 2: Build Registry ---") + registry = GroupedConvRegistry("basic_conv") + registry.add(config_minimal) + registry.add(config_explicit) + registry.add(config_full) + registry.print_registry() + + # ========================================================================= + # Step 3: Validate and auto-correct + # ========================================================================= + print("\n--- Step 3: Validate & Auto-Correct ---") + for i, cfg in enumerate(registry.kernels): + result = validate_grouped_conv_config(cfg.to_dict()) + if result.is_valid: + print(f" Config [{i}] {cfg.tile_str}: VALID") + else: + print(f" Config [{i}] {cfg.tile_str}: needs correction") + corrected, result = auto_correct_grouped_conv_config(cfg.to_dict()) + print(f" After correction: valid={result.is_valid}") + + # ========================================================================= + # Step 4: JIT compile via registry.build() + # ========================================================================= + print("\n--- Step 4: JIT Build (via registry.build()) ---") + + # Use only the first config for the actual GPU run + jit_reg = GroupedConvRegistry("jit") + jit_reg.add(config_minimal) + + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = jit_reg.build(verbose=False, max_workers=workers) + jit_build_s = time.perf_counter() - t0 + + key = (args.variant, args.ndim) + if key not in runners: + print(" JIT build failed") + return 1 + runner = runners[key] + print(f" JIT build: {jit_build_s:.3f} s") + print(f" Library: {runner.library_path}") + print(f" Kernels: {runner.lib.kernel_names()}") + + # ========================================================================= + # Step 5: Define problem + GPU execution + # ========================================================================= + print("\n--- Step 5: GPU Execution ---") + prob = GroupedConvProblem( + N=1, + C=64, + K=128, + Hi=16, + Wi=16, + Y=3, + X=3, + stride_h=1, + stride_w=1, + pad_h=1, + pad_w=1, + direction=args.variant, + ) + prob.print_problem() + + inp = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(np.float16) + wei = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np.float16) + + res = runner.run(inp, wei, prob) + if not res.success: + print(f" GPU execution failed: {res.error}") + runner.cleanup() + return 1 + + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print( + f" Output: shape={res.output.shape}, range=[{res.output.min():.3f}, {res.output.max():.3f}]" + ) + + # ========================================================================= + # Step 6: CPU reference (forward 2D only) + # ========================================================================= + verified = False + if args.variant == "forward" and args.ndim == 2: + print("\n--- Step 6: CPU Reference Verification ---") + ref = cpu_conv2d_fwd(inp, wei, prob) + gpu_f32 = res.output.astype(np.float32) + diff = np.abs(gpu_f32 - ref) + max_abs = diff.max() + max_rel = (diff / (np.abs(ref) + 1e-6)).max() + match = np.allclose(gpu_f32, ref, atol=0.05, rtol=0.05) + print(f" max_abs_diff: {max_abs:.6f}") + print(f" max_rel_diff: {max_rel:.6f}") + print(f" Match: {match}") + verified = match + + runner.cleanup() + + # Summary + print("\n" + "=" * 70) + status = ( + "PASS" if res.success and (verified or args.variant != "forward") else "FAIL" + ) + print(f" Status: {status}") + print( + f" {config_minimal.name} | {prob.gflops:.2f} GFLOPs | {res.tflops:.2f} TFLOPS" + ) + print(f" JIT build time: {jit_build_s:.3f} s") + print(f" Registry: {len(registry)} configs (3 patterns demonstrated)") + print("=" * 70) + return 0 if status == "PASS" else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/02_forward.py b/dispatcher/examples/grouped_conv/python/02_forward.py new file mode 100644 index 0000000000..8f59db05a1 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/02_forward.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 02: Forward Convolution (2D + 3D) + +Declares forward kernels with explicit tile/wave/warp/pipeline parameters, +builds a registry, JIT compiles, runs on GPU, and validates against CPU reference. + +Usage: + python3 02_forward.py + python3 02_forward.py --arch gfx942 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_fwd(inp, wei, prob): + """Naive CPU reference: 2D forward, NHWGC layout.""" + N, Hi, Wi, G, C = inp.shape + _, Kpg, Y, X, _ = wei.shape + Ho, Wo = prob.Ho, prob.Wo + out = np.zeros((N, Ho, Wo, G, Kpg), dtype=np.float32) + for n in range(N): + for g in range(G): + for ho in range(Ho): + for wo in range(Wo): + for k in range(Kpg): + s = 0.0 + for y in range(Y): + for x in range(X): + hi = ho * prob.stride_h - prob.pad_h + y + wi = wo * prob.stride_w - prob.pad_w + x + if 0 <= hi < Hi and 0 <= wi < Wi: + for c in range(C): + s += float(inp[n, hi, wi, g, c]) * float( + wei[g, k, y, x, c] + ) + out[n, ho, wo, g, k] = s + return out + + +def main(): + parser = argparse.ArgumentParser(description="Forward Convolution (2D + 3D)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 02: Forward Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + + # ========================================================================= + # Step 1: Declare forward kernels with explicit parameters + # ========================================================================= + print("\n--- Step 1: Declare Forward Kernels ---") + reg = GroupedConvRegistry("forward_conv") + + # Forward 2D: compv4, 128x128 tile, wave 2x2x1, warp 32x32x16 + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # Forward 3D: compv3, 64x64 tile, wave 1x4x1, warp 16x16x32 + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build via registry + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" Built {len(runners)} runners in {jit_s:.1f}s") + + for key in [("forward", 2), ("forward", 3)]: + tag = "OK" if key in runners else "FAILED" + print(f" {key[0]} {key[1]}D: {tag}") + + if ("forward", 2) not in runners: + print(" ERROR: forward 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: Forward 2D -- GPU + CPU reference + # ========================================================================= + print("\n--- Step 3: Forward 2D ---") + prob_2d = GroupedConvProblem( + N=1, C=64, K=64, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="forward" + ) + prob_2d.print_problem() + + x = np.random.uniform(-0.5, 0.5, prob_2d.input_shape()).astype(np_dtype) + w = np.random.uniform(-0.5, 0.5, prob_2d.weight_shape()).astype(np_dtype) + + res = runners[("forward", 2)].run(x, w, prob_2d) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print( + f" Output: shape={res.output.shape}, nonzero={np.count_nonzero(res.output)}/{res.output.size}" + ) + + ref = cpu_conv2d_fwd(x, w, prob_2d) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.05) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: Forward 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("forward", 3) in runners: + print("\n--- Step 4: Forward 3D ---") + prob_3d = GroupedConvProblem( + N=1, + C=64, + K=64, + Di=8, + Hi=8, + Wi=8, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="forward", + ) + prob_3d.print_problem() + + x3 = np.random.uniform(-0.5, 0.5, prob_3d.input_shape()).astype(np_dtype) + w3 = np.random.uniform(-0.5, 0.5, prob_3d.weight_shape()).astype(np_dtype) + + res3 = runners[("forward", 3)].run(x3, w3, prob_3d) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms") + print(f" TFLOPS: {res3.tflops:.2f}") + print(f" NonZero: {nz}/{res3.output.size}") + + for r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" Forward 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" Forward 3D: {'PASS' if ok_3d else 'FAIL'} (non-zero check)") + print(f" JIT build: {jit_s:.1f}s") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/03_bwd_data.py b/dispatcher/examples/grouped_conv/python/03_bwd_data.py new file mode 100644 index 0000000000..a000ba7c96 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/03_bwd_data.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 03: Backward Data Convolution (2D + 3D) + +dX = ConvBwdData(dY, W) + +Declares backward-data kernels with explicit parameters, +builds a registry, JIT compiles, runs on GPU, and validates +against a CPU reference. + +Usage: + python3 03_bwd_data.py +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_bwd_data(dy, wei, prob): + """CPU ref: compute dX from dY and W.""" + N, Ho, Wo, G, Kpg = dy.shape + _, _, Y, X, C = wei.shape + Hi, Wi = prob.Hi, prob.Wi + dx = np.zeros((N, Hi, Wi, G, C), dtype=np.float32) + for n in range(N): + for g in range(G): + for hi in range(Hi): + for wi in range(Wi): + for c in range(C): + s = 0.0 + for y in range(Y): + for x in range(X): + ho = hi + prob.pad_h - y + wo = wi + prob.pad_w - x + if ho % prob.stride_h == 0 and wo % prob.stride_w == 0: + ho //= prob.stride_h + wo //= prob.stride_w + if 0 <= ho < Ho and 0 <= wo < Wo: + for k in range(Kpg): + s += float(dy[n, ho, wo, g, k]) * float( + wei[g, k, y, x, c] + ) + dx[n, hi, wi, g, c] = s + return dx + + +def main(): + parser = argparse.ArgumentParser(description="Backward Data (2D + 3D)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 03: Backward Data Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + print(" dX = ConvBwdData(dY, W)") + + # ========================================================================= + # Step 1: Declare bwd_data kernels + # ========================================================================= + print("\n--- Step 1: Declare BwdData Kernels ---") + reg = GroupedConvRegistry("bwd_data_conv") + + # BwdData 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdData 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" Built {len(runners)} runners in {jit_s:.1f}s") + + if ("bwd_data", 2) not in runners: + print(" ERROR: bwd_data 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: BwdData 2D -- GPU + CPU reference + # ========================================================================= + print("\n--- Step 3: Backward Data 2D ---") + prob = GroupedConvProblem( + N=1, C=32, K=32, Hi=8, Wi=8, Y=3, X=3, pad_h=1, pad_w=1, direction="bwd_data" + ) + prob.print_problem() + + dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(np_dtype) + w = np.random.uniform(-0.5, 0.5, prob.weight_shape()).astype(np_dtype) + + res = runners[("bwd_data", 2)].run(dy, w, prob) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + ref = cpu_conv2d_bwd_data(dy, w, prob) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.1) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: BwdData 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("bwd_data", 3) in runners: + print("\n--- Step 4: Backward Data 3D ---") + prob3 = GroupedConvProblem( + N=1, + C=32, + K=32, + Di=6, + Hi=6, + Wi=6, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="bwd_data", + ) + dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype) + w3 = np.random.uniform(-0.5, 0.5, prob3.weight_shape()).astype(np_dtype) + res3 = runners[("bwd_data", 3)].run(dy3, w3, prob3) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}") + + for r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" BwdData 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" BwdData 3D: {'PASS' if ok_3d else 'FAIL'}") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/04_bwd_weight.py b/dispatcher/examples/grouped_conv/python/04_bwd_weight.py new file mode 100644 index 0000000000..48e50cd4a9 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/04_bwd_weight.py @@ -0,0 +1,224 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 04: Backward Weight Convolution (2D + 3D) + +dW = ConvBwdWeight(X, dY) + +Declares backward-weight kernels with explicit parameters, +builds a registry, JIT compiles, runs on GPU, and validates +against a CPU reference. + +Usage: + python3 04_bwd_weight.py +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def cpu_conv2d_bwd_weight(x, dy, prob): + """CPU ref: compute dW from X and dY.""" + N, Hi, Wi, G, C = x.shape + _, Ho, Wo, _, Kpg = dy.shape + Y, X_ = prob.Y, prob.X + dw = np.zeros((G, Kpg, Y, X_, C), dtype=np.float32) + for g in range(G): + for k in range(Kpg): + for y in range(Y): + for xf in range(X_): + for c in range(C): + s = 0.0 + for n in range(N): + for ho in range(Ho): + for wo in range(Wo): + hi = ho * prob.stride_h - prob.pad_h + y + wi = wo * prob.stride_w - prob.pad_w + xf + if 0 <= hi < Hi and 0 <= wi < Wi: + s += float(x[n, hi, wi, g, c]) * float( + dy[n, ho, wo, g, k] + ) + dw[g, k, y, xf, c] = s + return dw + + +def main(): + parser = argparse.ArgumentParser(description="Backward Weight (2D + 3D)") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + parser.add_argument( + "--split-k", type=int, default=1, help="Split-K factor for bwd_weight (k_batch)" + ) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 04: Backward Weight Convolution (2D + 3D)") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + print(" dW = ConvBwdWeight(X, dY)") + + # ========================================================================= + # Step 1: Declare bwd_weight kernels + # ========================================================================= + print("\n--- Step 1: Declare BwdWeight Kernels ---") + reg = GroupedConvRegistry("bwd_weight_conv") + + # BwdWeight 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_weight", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdWeight 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_weight", + ndim_spatial=3, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runners = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + print(f" Built {len(runners)} runners in {jit_s:.1f}s") + + if ("bwd_weight", 2) not in runners: + print(" ERROR: bwd_weight 2D JIT failed") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + # ========================================================================= + # Step 3: BwdWeight 2D -- GPU + CPU reference + # ========================================================================= + print("\n--- Step 3: Backward Weight 2D ---") + prob = GroupedConvProblem( + N=1, + C=32, + K=32, + Hi=8, + Wi=8, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="bwd_weight", + split_k=args.split_k, + ) + prob.print_problem() + + x = np.random.uniform(-0.5, 0.5, prob.input_shape()).astype(np_dtype) + dy = np.random.uniform(-0.5, 0.5, prob.output_shape()).astype(np_dtype) + + res = runners[("bwd_weight", 2)].run(x, dy, prob) + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + ref = cpu_conv2d_bwd_weight(x, dy, prob) + diff = np.abs(res.output.astype(np.float32) - ref) + match_2d = np.allclose(res.output.astype(np.float32), ref, atol=0.5) + print(f" CPU ref: max_abs={diff.max():.6f}, match={match_2d}") + + # ========================================================================= + # Step 4: BwdWeight 3D -- GPU + non-zero check + # ========================================================================= + ok_3d = True + if ("bwd_weight", 3) in runners: + print("\n--- Step 4: Backward Weight 3D ---") + prob3 = GroupedConvProblem( + N=1, + C=32, + K=32, + Di=6, + Hi=6, + Wi=6, + Z=3, + Y=3, + X=3, + pad_d=1, + pad_h=1, + pad_w=1, + direction="bwd_weight", + ) + x3 = np.random.uniform(-0.5, 0.5, prob3.input_shape()).astype(np_dtype) + dy3 = np.random.uniform(-0.5, 0.5, prob3.output_shape()).astype(np_dtype) + res3 = runners[("bwd_weight", 3)].run(x3, dy3, prob3) + nz = np.count_nonzero(res3.output) + ok_3d = res3.success and nz > 0 + print(f" Time: {res3.time_ms:.4f} ms, NonZero: {nz}/{res3.output.size}") + + for r in runners.values(): + r.cleanup() + + passed = res.success and match_2d and ok_3d + print("\n" + "=" * 70) + print(f" BwdWeight 2D: {'PASS' if match_2d else 'FAIL'} (CPU validated)") + print(f" BwdWeight 3D: {'PASS' if ok_3d else 'FAIL'}") + print(f" Status: {'PASS' if passed else 'FAIL'}") + print("=" * 70) + return 0 if passed else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/05_benchmark.py b/dispatcher/examples/grouped_conv/python/05_benchmark.py new file mode 100644 index 0000000000..9166ab988e --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/05_benchmark.py @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 05: Multi-Problem GPU Benchmark + +Declares kernels with explicit tile/wave/warp/pipeline parameters for +all directions, builds registries, JIT compiles, and benchmarks across +ResNet-like problem sizes with configurable warmup/repeat. + +Usage: + python3 05_benchmark.py + python3 05_benchmark.py --warmup 3 --repeat 10 + python3 05_benchmark.py --workers 4 +""" + +import sys +import argparse +import time +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def compute_bytes(prob, dtype_bytes=2): + in_elems = 1 + for d in prob.input_shape(): + in_elems *= d + wei_elems = 1 + for d in prob.weight_shape(): + wei_elems *= d + out_elems = 1 + for d in prob.output_shape(): + out_elems *= d + return (in_elems + wei_elems + out_elems) * dtype_bytes + + +def main(): + parser = argparse.ArgumentParser(description="Multi-Problem GPU Benchmark") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations") + parser.add_argument("--repeat", type=int, default=5, help="Benchmark iterations") + parser.add_argument( + "--workers", type=int, default=0, help="Max JIT workers (0=auto)" + ) + args = parser.parse_args() + + print("=" * 70) + print("Example 05: Multi-Problem GPU Benchmark") + print("=" * 70) + print(f"\n Arch: {args.arch}, Dtype: {args.dtype}") + print(f" Warmup: {args.warmup}, Repeat: {args.repeat}") + + # ========================================================================= + # Step 1: Declare all kernels with explicit parameters + # ========================================================================= + print("\n--- Step 1: Declare Kernels ---") + reg = GroupedConvRegistry("benchmark") + + # Forward 2D: compv4, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # Forward 3D: compv3, 64x64 tile + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=3, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdData 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_data", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + # BwdWeight 2D: compv3, 128x128 tile + reg.add( + GroupedConvKernelConfig( + variant="bwd_weight", + ndim_spatial=2, + arch=args.arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + ) + ) + reg.print_registry() + + # ========================================================================= + # Step 2: JIT build + # ========================================================================= + print("\n--- Step 2: JIT Build ---") + workers = args.workers if args.workers > 0 else None + t0 = time.perf_counter() + runner_by_key = reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + + for key in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)]: + tag = "OK" if key in runner_by_key else "FAILED" + print(f" {key[0]:12s} {key[1]}D: {tag}") + print(f" JIT build time: {jit_s:.3f} s") + + missing = [ + k + for k in [("forward", 2), ("forward", 3), ("bwd_data", 2), ("bwd_weight", 2)] + if k not in runner_by_key + ] + if missing: + print(f"\n ERROR: missing {missing}") + return 1 + + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + + def bench_run(runner, inp, wei, prob): + for _ in range(args.warmup): + runner.run(inp, wei, prob) + times = [] + for _ in range(args.repeat): + r = runner.run(inp, wei, prob) + if r.success: + times.append(r.time_ms) + if not times: + return 0.0, 0.0 + return min(times), sum(times) / len(times) + + # ========================================================================= + # Step 3: 2D Forward benchmark + # ========================================================================= + print("\n--- Step 3: Forward 2D Benchmark ---") + print( + f"{'Problem':<18} {'N':>3} {'C':>4} {'K':>4} {'H':>3} {'W':>3} " + f"{'F':>3} {'Min(ms)':>9} {'Avg(ms)':>9} {'TFLOPS':>8} {'GB/s':>8}" + ) + print("-" * 85) + + all_ok = True + for label, n, c, k, h, w, y, x, s, p in [ + ("ResNet-stage2", 1, 64, 64, 56, 56, 3, 3, 1, 1), + ("ResNet-stage3", 1, 128, 128, 28, 28, 3, 3, 1, 1), + ("ResNet-stage4", 1, 256, 256, 14, 14, 3, 3, 1, 1), + ("ResNet-stage5", 1, 512, 512, 7, 7, 3, 3, 1, 1), + ("Pointwise-1x1", 1, 256, 256, 56, 56, 1, 1, 1, 0), + ("Batch-8", 8, 64, 128, 56, 56, 3, 3, 1, 1), + ("Batch-32", 32, 64, 128, 56, 56, 3, 3, 1, 1), + ]: + prob = GroupedConvProblem( + N=n, + C=c, + K=k, + Hi=h, + Wi=w, + Y=y, + X=x, + stride_h=s, + stride_w=s, + pad_h=p, + pad_w=p, + direction="forward", + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[("forward", 2)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + bw = compute_bytes(prob) / (avg_ms * 1e6) + print( + f"{label:<18} {n:>3} {c:>4} {k:>4} {h:>3} {w:>3} " + f"{y}x{x} {min_ms:>9.4f} {avg_ms:>9.4f} {tflops:>8.2f} {bw:>8.1f}" + ) + else: + all_ok = False + + # ========================================================================= + # Step 4: 3D Forward + # ========================================================================= + print("\n--- Step 4: Forward 3D ---") + for label, n, c, k, d, h, w, z, y, x in [ + ("3D-small", 1, 64, 64, 8, 16, 16, 3, 3, 3), + ("3D-medium", 1, 64, 128, 16, 32, 32, 3, 3, 3), + ]: + prob = GroupedConvProblem( + N=n, C=c, K=k, Di=d, Hi=h, Wi=w, Z=z, Y=y, X=x, direction="forward" + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[("forward", 3)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + print(f" {label:<14} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS") + + # ========================================================================= + # Step 5: Backward directions + # ========================================================================= + print("\n--- Step 5: Backward Directions ---") + for label, direction in [ + ("bwd_data ResNet-s3", "bwd_data"), + ("bwd_weight ResNet-s3", "bwd_weight"), + ]: + prob = GroupedConvProblem( + N=1, + C=128, + K=128, + Hi=28, + Wi=28, + Y=3, + X=3, + stride_h=1, + stride_w=1, + pad_h=1, + pad_w=1, + direction=direction, + ) + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + min_ms, avg_ms = bench_run(runner_by_key[(direction, 2)], inp, wei, prob) + if avg_ms > 0: + tflops = prob.flops / (avg_ms * 1e9) + print( + f" {label:<14} {direction:>12} {min_ms:.4f} / {avg_ms:.4f} ms {tflops:.2f} TFLOPS" + ) + + for runner in runner_by_key.values(): + runner.cleanup() + + print("\n" + "=" * 70) + print(f" JIT build: {jit_s:.3f} s") + print(f" Warmup: {args.warmup}, Repeat: {args.repeat}") + print(f" Status: {'PASS' if all_ok else 'FAIL'}") + print("=" * 70) + return 0 if all_ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/examples/grouped_conv/python/06_registry_json.py b/dispatcher/examples/grouped_conv/python/06_registry_json.py new file mode 100644 index 0000000000..1a3dc854e7 --- /dev/null +++ b/dispatcher/examples/grouped_conv/python/06_registry_json.py @@ -0,0 +1,274 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Example 06: Registry, Heuristic Selection & JSON Export + +Declares multiple kernel configurations with different tile sizes, +builds a registry, demonstrates heuristic runtime kernel selection, +JSON round-trip, and GPU execution. + +Usage: + python3 06_registry_json.py + python3 06_registry_json.py --workers 4 +""" + +import sys +import time +import argparse +import numpy as np +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) + +from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GroupedConvRegistry, + detect_gpu_arch, +) + + +def conv_heuristic(problem): + spatial = problem.Ho * problem.Wo + if spatial > 400: + return ["256", "128", "64"] + return ["64", "128", "256"] + + +def main(): + parser = argparse.ArgumentParser(description="Registry, Heuristic & JSON") + parser.add_argument("--arch", default=detect_gpu_arch()) + parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"]) + parser.add_argument("--workers", type=int, default=0) + args = parser.parse_args() + + arch = args.arch + print("=" * 70) + print("Example 06: Registry, Heuristic Selection & JSON Export") + print("=" * 70) + print(f"\n Arch: {arch}, Dtype: {args.dtype}") + + # Step 1: Declare kernels with full explicit parameters + print("\n--- Step 1: Declare Kernels + Build Registry ---") + reg = GroupedConvRegistry("conv_tiles") + + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=256, + tile_k=256, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=64, + tile_k=64, + wave_m=1, + wave_n=4, + wave_k=1, + warp_tile_m=16, + warp_tile_n=16, + warp_tile_k=32, + pipeline="compv3", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + block_per_cu=1, + num_wave_groups=1, + num_groups_to_merge=1, + ) + ) + reg.print_registry() + + # Step 2: Heuristic kernel selection + print("\n--- Step 2: Heuristic Kernel Selection ---") + problems = [ + ( + "small_7x7", + GroupedConvProblem( + N=1, + C=512, + K=512, + Hi=7, + Wi=7, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ( + "medium_14x14", + GroupedConvProblem( + N=1, + C=256, + K=256, + Hi=14, + Wi=14, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ( + "large_56x56", + GroupedConvProblem( + N=1, + C=64, + K=128, + Hi=56, + Wi=56, + Y=3, + X=3, + pad_h=1, + pad_w=1, + direction="forward", + ), + ), + ] + print(f" {'Problem':<16} {'Spatial':>8} {'Selected Kernel':<50}") + print(f" {'-' * 74}") + for label, prob in problems: + selected = reg.select(prob, heuristic=conv_heuristic) + spatial = prob.Ho * prob.Wo + sel_name = selected.name if selected else "none" + print(f" {label:<16} {spatial:>8} {sel_name:<50}") + + # Step 3: JSON round-trip + print("\n--- Step 3: JSON Round-Trip ---") + json_str = reg.to_json() + print(f" Exported: {len(json_str)} bytes, {len(reg)} kernels") + imported = GroupedConvRegistry.from_json(json_str) + print(f" Imported: {len(imported)} kernels") + orig = reg.kernels[0] + imp = imported.kernels[0] + rt_ok = ( + orig.vector_size_a == imp.vector_size_a + and orig.block_per_cu == imp.block_per_cu + and orig.tile_n == imp.tile_n + ) + print(f" Full fields round-trip: {'OK' if rt_ok else 'FAIL'}") + + # Step 4: JIT build + GPU execution + print("\n--- Step 4: JIT Build + GPU Execution ---") + workers = args.workers if args.workers > 0 else None + jit_reg = GroupedConvRegistry("jit_conv") + jit_reg.add( + GroupedConvKernelConfig( + variant="forward", + ndim_spatial=2, + arch=arch, + dtype=args.dtype, + tile_m=1, + tile_n=128, + tile_k=128, + wave_m=2, + wave_n=2, + wave_k=1, + warp_tile_m=32, + warp_tile_n=32, + warp_tile_k=16, + pipeline="compv4", + scheduler="intrawave", + epilogue="cshuffle", + vector_size_a=4, + vector_size_b=8, + vector_size_c=8, + ) + ) + t0 = time.perf_counter() + runners = jit_reg.build(verbose=False, max_workers=workers) + jit_s = time.perf_counter() - t0 + + if ("forward", 2) not in runners: + print(" JIT build failed") + return 1 + runner = runners[("forward", 2)] + print(f" JIT build: {jit_s:.3f} s") + print(f" Library: {runner.library_path}") + + prob = GroupedConvProblem( + N=1, C=128, K=128, Hi=16, Wi=16, Y=3, X=3, pad_h=1, pad_w=1, direction="forward" + ) + np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32 + inp = np.random.uniform(-0.3, 0.3, prob.input_shape()).astype(np_dtype) + wei = np.random.uniform(-0.3, 0.3, prob.weight_shape()).astype(np_dtype) + res = runner.run(inp, wei, prob) + runner.cleanup() + + if res.success: + print(f" Time: {res.time_ms:.4f} ms") + print(f" TFLOPS: {res.tflops:.2f}") + print(f" NonZero: {np.count_nonzero(res.output)}/{res.output.size}") + + gpu_ok = res.success + print("\n" + "=" * 70) + print(f" Registry: {len(reg)} kernels (3 tile configs)") + print(" Heuristic: spatial-based selection demonstrated") + print(f" JSON: round-trip {'OK' if rt_ok else 'FAIL'}") + print(f" GPU: {'OK' if gpu_ok else 'FAIL'}") + print(f" Status: {'PASS' if gpu_ok and rt_ok else 'FAIL'}") + print("=" * 70) + return 0 if gpu_ok and rt_ok else 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/include/ck_tile/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher.hpp index 98d8bb9333..b3d8f10675 100644 --- a/dispatcher/include/ck_tile/dispatcher.hpp +++ b/dispatcher/include/ck_tile/dispatcher.hpp @@ -3,9 +3,17 @@ #pragma once -/// Main dispatcher header - includes all core components -/// Use this for convenient access to the full dispatcher API +/// Full dispatcher header - includes ALL operation types. +/// For minimal includes, use the per-operation headers instead: +/// ck_tile/dispatcher_gemm.hpp -- GEMM only +/// ck_tile/dispatcher_conv.hpp -- Grouped Convolution only +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// GEMM #include "ck_tile/dispatcher/kernel_key.hpp" #include "ck_tile/dispatcher/kernel_config.hpp" #include "ck_tile/dispatcher/kernel_decl.hpp" @@ -13,7 +21,15 @@ #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/registry.hpp" #include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/json_export.hpp" #include "ck_tile/dispatcher/arch_filter.hpp" #include "ck_tile/dispatcher/backends/tile_backend.hpp" #include "ck_tile/dispatcher/backends/generated_tile_backend.hpp" #include "ck_tile/dispatcher/utils.hpp" + +// Grouped Convolution +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher/README.md b/dispatcher/include/ck_tile/dispatcher/README.md index db3ce996a9..430798aedd 100644 --- a/dispatcher/include/ck_tile/dispatcher/README.md +++ b/dispatcher/include/ck_tile/dispatcher/README.md @@ -1,6 +1,6 @@ # CK Tile Dispatcher - C++ Headers -C++ API for the CK Tile dispatcher. +C++ API for the CK Tile dispatcher (GEMM and Grouped Convolution). > **See also:** [Main Dispatcher README](../../../../README.md) for installation and core concepts. @@ -8,16 +8,25 @@ C++ API for the CK Tile dispatcher. ``` dispatcher/ -├── dispatcher.hpp # Main dispatcher (kernel selection) -├── registry.hpp # Kernel registry (storage & lookup) -├── problem.hpp # Problem specification -├── kernel_key.hpp # Kernel configuration key -├── kernel_instance.hpp # Kernel instance interface -├── utils.hpp # Utilities (timers, GPU buffers) -│ -└── backends/ # Backend implementations - ├── generated_tile_backend.hpp # CK Tile kernels (production) - └── tile_backend.hpp # Tile backend base +|---- dispatcher.hpp # Main include (includes all below) +| +|---- # GEMM Headers +|---- registry.hpp # Kernel registry (storage & lookup) +|---- problem.hpp # GEMM problem specification +|---- kernel_key.hpp # Kernel configuration key +|---- kernel_instance.hpp # Kernel instance interface +|---- utils.hpp # Utilities (timers, GPU buffers) +| +|---- # Grouped Convolution Headers +|---- grouped_conv_config.hpp # GroupedConvDirection, GroupedConvConfig +|---- grouped_conv_problem.hpp # GroupedConvProblem + ProblemBuilder +|---- grouped_conv_kernel_decl.hpp # GroupedConvKernelDecl, DECL_GROUPED_CONV_KERNEL_SET +|---- grouped_conv_registry.hpp # Thread-safe registry with JSON export & filtering +|---- grouped_conv_utils.hpp # Config creators, validation, benchmark utilities +| ++---- backends/ # Backend implementations + |---- generated_tile_backend.hpp # CK Tile kernels (production) + +---- tile_backend.hpp # Tile backend base ``` ## Quick Start @@ -148,6 +157,69 @@ auto kernel = create_generated_tile_kernel< >(key, name); ``` +## Grouped Convolution API + +### GroupedConvProblem (`grouped_conv_problem.hpp`) + +Problem specification with builder pattern: + +```cpp +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" + +using namespace ck_tile::dispatcher; + +auto problem = GroupedConvProblemBuilder() + .n(2).g(1).c(128).k(256) + .input_spatial({28, 28}) + .filter_spatial({3, 3}) + .strides({1, 1}) + .dilations({1, 1}) + .left_pads({1, 1}) + .right_pads({1, 1}) + .build(); + +bool ok = problem.is_valid(); +``` + +### GroupedConvRegistry (`grouped_conv_registry.hpp`) + +Thread-safe registry with JSON export and filtering: + +```cpp +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" + +auto& registry = GroupedConvRegistry::instance(); + +// Thread-safe registration +registry.register_kernel(kernel); + +// JSON export +std::string json = registry.export_json(); +registry.export_json_to_file("kernels.json"); + +// Filtering +auto gfx942_kernels = registry.filter_by_arch("gfx942"); +auto matched = registry.filter([](const auto& k) { return k.is_fwd(); }); +``` + +### DECL_GROUPED_CONV_KERNEL_SET (`grouped_conv_kernel_decl.hpp`) + +Declarative kernel definition: + +```cpp +DECL_GROUPED_CONV_KERNEL_SET(my_conv_kernels, + .add( + GroupedConvSignature().dtype("fp16").layout("nhwgc"), + GroupedConvAlgorithm().tile(128, 128, 32).wave(2, 2, 1) + .warp(32, 32, 16).pipeline("compv4"), + "gfx942" + ) +); + +// Register all matching current arch +DECL_GROUPED_CONV_KERNEL_ALL(all_conv_kernels, "gfx942"); +``` + ## Best Practices 1. Use `Release` build for performance @@ -155,6 +227,8 @@ auto kernel = create_generated_tile_kernel< 3. Use `Priority::High` for hand-tuned kernels 4. Reuse dispatcher instances 5. Clear registry between test runs +6. Use `GroupedConvProblemBuilder` for validated problem construction +7. Leverage `export_json()` for kernel inventory and debugging --- diff --git a/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp b/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp new file mode 100644 index 0000000000..04ee1b2d11 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp @@ -0,0 +1,152 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +// Generated Convolution Kernel Backend +// +// Wraps CK Tile grouped convolution launchers for use through the +// GroupedConvDispatcher. Each generated kernel launcher is wrapped in +// a ConvKernelRunFn that builds the correct host-args type (forward, +// bwd-data, or bwd-weight) and calls Launcher::launch(). + +#pragma once + +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/convolution_parameter.hpp" +#include "ck_tile/ops/grouped_convolution.hpp" +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace backends { + +// Buffer context is defined in grouped_conv_registry.hpp (g_conv_dispatch_buffers) +// so there's no circular dependency. + +// Helper: build ck_tile::conv::ConvParam from GroupedConvProblem +inline ck_tile::conv::ConvParam make_conv_param_2d(const GroupedConvProblem& p) +{ + return ck_tile::conv::ConvParam{ + 2, + static_cast(p.G), + static_cast(p.N), + static_cast(p.K), + static_cast(p.C), + {static_cast(p.filter_spatial[1]), + static_cast(p.filter_spatial[2])}, + {static_cast(p.input_spatial[1]), + static_cast(p.input_spatial[2])}, + {static_cast(p.stride[1]), static_cast(p.stride[2])}, + {static_cast(p.dilation[1]), + static_cast(p.dilation[2])}, + {static_cast(p.padding[1]), static_cast(p.padding[2])}, + {static_cast(p.padding[1]), static_cast(p.padding[2])}}; +} + +inline ck_tile::conv::ConvParam make_conv_param_3d(const GroupedConvProblem& p) +{ + return ck_tile::conv::ConvParam{3, + static_cast(p.G), + static_cast(p.N), + static_cast(p.K), + static_cast(p.C), + {static_cast(p.filter_spatial[0]), + static_cast(p.filter_spatial[1]), + static_cast(p.filter_spatial[2])}, + {static_cast(p.input_spatial[0]), + static_cast(p.input_spatial[1]), + static_cast(p.input_spatial[2])}, + {static_cast(p.stride[0]), + static_cast(p.stride[1]), + static_cast(p.stride[2])}, + {static_cast(p.dilation[0]), + static_cast(p.dilation[1]), + static_cast(p.dilation[2])}, + {static_cast(p.padding[0]), + static_cast(p.padding[1]), + static_cast(p.padding[2])}, + {static_cast(p.padding[0]), + static_cast(p.padding[1]), + static_cast(p.padding[2])}}; +} + +// Create a RunFn for a forward convolution launcher (2D or 3D) +template +inline GroupedConvKernelInstance::RunFn make_conv_fwd_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + ck_tile::GroupedConvFwdHostArgs<> args( + param, ctx.input_ptr, ctx.weight_ptr, {}, ctx.output_ptr, 1); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +// Create a RunFn for a backward-data convolution launcher. +// Dispatcher convention: run(dY, W, dX, problem) where dX is computed. +// BwdDataHostArgs(param, in_ptr=dX, wei_ptr=W, {}, out_ptr=dY, k_batch) +template +inline GroupedConvKernelInstance::RunFn make_conv_bwd_data_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + ck_tile::GroupedConvBwdDataHostArgs args( + param, + ctx.output_ptr, // in_ptr = dX (being computed) + ctx.weight_ptr, // wei_ptr = W + {}, + ctx.input_ptr, // out_ptr = dY (gradient from next layer) + 1); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +// Create a RunFn for a backward-weight convolution launcher. +// Dispatcher convention: run(X, dY, dW, problem) where dW is computed. +// BwdWeightHostArgs(param, in_ptr=X, wei_ptr=dW, {}, out_ptr=dY, k_batch) +template +inline GroupedConvKernelInstance::RunFn make_conv_bwd_weight_run_fn() +{ + return [](const GroupedConvProblem& problem, void* stream) -> float { + auto& ctx = g_conv_dispatch_buffers; + auto param = (NDim == 2) ? make_conv_param_2d(problem) : make_conv_param_3d(problem); + const int k_batch = (ctx.split_k > 1) ? ctx.split_k : 1; + ck_tile::GroupedConvBwdWeightHostArgs args(param, + ctx.input_ptr, // in_ptr = X + ctx.output_ptr, // wei_ptr = dW (being computed) + {}, + ctx.weight_ptr, // out_ptr = dY + k_batch); + ck_tile::stream_config sc; + sc.stream_id_ = reinterpret_cast(stream); + sc.time_kernel_ = ctx.benchmarking; + sc.log_level_ = 0; + sc.cold_niters_ = ctx.benchmarking ? ctx.warmup : 0; + sc.nrepeat_ = ctx.benchmarking ? ctx.repeat : 1; + sc.is_gpu_timer_ = ctx.benchmarking; + return LauncherType::launch(args, sc); + }; +} + +} // namespace backends +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/base_registry.hpp b/dispatcher/include/ck_tile/dispatcher/base_registry.hpp new file mode 100644 index 0000000000..2bb940c320 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/base_registry.hpp @@ -0,0 +1,199 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Shared priority enum used by all registry types +enum class Priority +{ + Low = 0, + Normal = 1, + High = 2 +}; + +/// BaseRegistry: Thread-safe, priority-aware kernel storage shared by GEMM and Conv registries. +/// +/// Template Parameters: +/// Derived - CRTP derived class (e.g., Registry, ConvRegistry) +/// KeyType - primary key type (std::string for GEMM, ConvKernelKey for Conv) +/// InstanceType - kernel instance type (KernelInstance, ConvKernelInstance) +/// KeyHash - hash functor for KeyType (defaults to std::hash) +template > +class BaseRegistry +{ + public: + using InstancePtr = std::shared_ptr; + + struct Entry + { + InstancePtr instance; + Priority priority; + }; + + BaseRegistry() = default; + virtual ~BaseRegistry() = default; + + BaseRegistry(BaseRegistry&& other) noexcept + { + std::lock_guard lock(other.mutex_); + entries_ = std::move(other.entries_); + name_ = std::move(other.name_); + } + + BaseRegistry& operator=(BaseRegistry&& other) noexcept + { + if(this != &other) + { + std::scoped_lock lock(mutex_, other.mutex_); + entries_ = std::move(other.entries_); + name_ = std::move(other.name_); + } + return *this; + } + + BaseRegistry(const BaseRegistry&) = delete; + BaseRegistry& operator=(const BaseRegistry&) = delete; + + /// Register a kernel. If the key already exists, the new entry replaces it + /// unless the existing entry has strictly higher priority. + /// Same-priority registration overwrites (last-writer-wins at equal priority). + bool + register_kernel(const KeyType& key, InstancePtr instance, Priority priority = Priority::Normal) + { + std::lock_guard lock(mutex_); + auto it = entries_.find(key); + if(it != entries_.end() && it->second.priority > priority) + { + return false; + } + entries_[key] = Entry{std::move(instance), priority}; + return true; + } + + [[nodiscard]] std::size_t size() const + { + std::lock_guard lock(mutex_); + return entries_.size(); + } + + [[nodiscard]] bool empty() const + { + std::lock_guard lock(mutex_); + return entries_.empty(); + } + + void clear() + { + std::lock_guard lock(mutex_); + entries_.clear(); + } + + [[nodiscard]] std::string get_name() const + { + std::lock_guard lock(mutex_); + return name_; // return by value to avoid dangling reference + } + + void set_name(const std::string& name) + { + std::lock_guard lock(mutex_); + name_ = name; + } + + [[nodiscard]] std::vector get_all_instances() const + { + std::lock_guard lock(mutex_); + std::vector result; + result.reserve(entries_.size()); + for(const auto& [key, entry] : entries_) + { + result.push_back(entry.instance); + } + return result; + } + + std::size_t merge_from(const BaseRegistry& other, Priority priority = Priority::Normal) + { + std::scoped_lock lock(mutex_, other.mutex_); + std::size_t merged = 0; + for(const auto& [key, entry] : other.entries_) + { + auto it = entries_.find(key); + if(it == entries_.end() || it->second.priority <= priority) + { + entries_[key] = Entry{entry.instance, priority}; + ++merged; + } + } + return merged; + } + + /// Enable automatic JSON export after every kernel registration. + /// Requires the derived class to implement export_json_to_file(path, stats). + void enable_auto_export(const std::string& path, + bool include_statistics = true, + bool export_on_every_registration = true) + { + std::lock_guard lock(mutex_); + auto_export_path_ = path; + auto_export_stats_ = include_statistics; + auto_export_on_register_ = export_on_every_registration; + auto_export_enabled_.store(true, std::memory_order_release); + } + + void disable_auto_export() { auto_export_enabled_.store(false, std::memory_order_release); } + + [[nodiscard]] bool is_auto_export_enabled() const + { + return auto_export_enabled_.load(std::memory_order_acquire); + } + + /// Call after registration to trigger auto-export if enabled. + void perform_auto_export() + { + if(!auto_export_enabled_.load(std::memory_order_acquire)) + return; + std::lock_guard lock(mutex_); + if(auto_export_on_register_) + { + static_cast(this)->export_json_to_file(auto_export_path_, auto_export_stats_); + } + } + + protected: + [[nodiscard]] const std::unordered_map& entries() const + { + return entries_; + } + + [[nodiscard]] std::unordered_map& entries_mut() { return entries_; } + + std::mutex& mutex() const { return mutex_; } + + private: + mutable std::mutex mutex_; + std::unordered_map entries_; + std::string name_ = "default"; + + std::atomic auto_export_enabled_{false}; + bool auto_export_on_register_ = true; + bool auto_export_stats_ = true; + std::string auto_export_path_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp index 6d3f548138..d266d693da 100644 --- a/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher.hpp @@ -23,6 +23,7 @@ #pragma once +#include "ck_tile/dispatcher/dispatcher_error.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/problem.hpp" #include "ck_tile/dispatcher/registry.hpp" @@ -52,7 +53,11 @@ class Dispatcher /// Constructor /// @param registry Registry instance to use (default: global singleton) - explicit Dispatcher(Registry* registry = nullptr); + /// @param gfx_arch Target GPU architecture (e.g. "gfx950") + explicit Dispatcher(Registry* registry = nullptr, const std::string& gfx_arch = ""); + + void set_arch(const std::string& arch) { gfx_arch_ = arch; } + [[nodiscard]] const std::string& arch() const { return gfx_arch_; } /// Register a heuristic function for kernel selection /// @param heuristic Function that maps problems to ranked kernel identifiers @@ -74,7 +79,7 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if no suitable kernel found + /// @throws NoKernelFound if no suitable kernel found [[nodiscard]] float run(const void* a_ptr, const void* b_ptr, void* c_ptr, @@ -89,7 +94,7 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if no suitable kernel found + /// @throws NoKernelFound if no suitable kernel found [[nodiscard]] float run_fused(const void* a_ptr, const void* b_ptr, void* c_ptr, @@ -106,7 +111,8 @@ class Dispatcher /// @param problem Problem configuration /// @param stream HIP stream for kernel launch (nullptr = default stream) /// @return Kernel execution time in milliseconds - /// @throws std::runtime_error if kernel not found or doesn't support problem + /// @throws NoKernelFound if the kernel identifier is not registered + /// @throws UnsupportedProblem if the selected kernel does not support the problem [[nodiscard]] float run_explicit(const std::string& kernel_id, const void* a_ptr, const void* b_ptr, @@ -130,10 +136,18 @@ class Dispatcher const Problem& problem, float tolerance = 1e-3f) const; + /// Enable or disable GPU benchmarking (timing) on all kernels. + /// When disabled, kernels execute once with no timing overhead + /// (one-shot mode for production plugins). + void set_benchmarking(bool enable) { benchmarking_ = enable; } + [[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; } + private: Registry* registry_; HeuristicFunction heuristic_; SelectionStrategy strategy_; + std::string gfx_arch_; + bool benchmarking_ = true; /// Select kernel using first-fit strategy [[nodiscard]] KernelInstancePtr select_first_fit(const Problem& problem) const; diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp new file mode 100644 index 0000000000..98b079f8d9 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp @@ -0,0 +1,28 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +namespace ck_tile { +namespace dispatcher { + +struct DispatcherError : std::runtime_error +{ + using std::runtime_error::runtime_error; +}; + +struct NoKernelFound : DispatcherError +{ + using DispatcherError::DispatcherError; +}; + +struct UnsupportedProblem : DispatcherError +{ + using DispatcherError::DispatcherError; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp b/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp new file mode 100644 index 0000000000..6a39766649 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp @@ -0,0 +1,55 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/// Log levels for dispatcher transparency: +/// 0 = silent (default) +/// 1 = print selected kernel name +/// 2 = print all candidates considered and acceptance/rejection reasons +inline int get_log_level() +{ + static int level = []() { + const char* env = std::getenv("CK_DISPATCHER_LOG_LEVEL"); + return env ? std::atoi(env) : 0; + }(); + return level; +} + +inline void log_kernel_selected(const std::string& kernel_name, const std::string& problem_desc) +{ + if(get_log_level() >= 1) + { + std::cerr << "[CK Dispatcher] Selected kernel: " << kernel_name << " for " << problem_desc + << std::endl; + } +} + +inline void +log_kernel_candidate(const std::string& kernel_name, bool accepted, const std::string& reason) +{ + if(get_log_level() >= 2) + { + std::cerr << "[CK Dispatcher] Candidate: " << kernel_name << " -> " + << (accepted ? "ACCEPTED" : "REJECTED") + << (reason.empty() ? "" : " (" + reason + ")") << std::endl; + } +} + +inline void log_no_kernel_found(const std::string& problem_desc) +{ + if(get_log_level() >= 1) + { + std::cerr << "[CK Dispatcher] No kernel found for " << problem_desc << std::endl; + } +} + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp new file mode 100644 index 0000000000..91b7b3ad74 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp @@ -0,0 +1,588 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_config.hpp + * @brief CK Tile Grouped Convolution Configuration with Builder-style naming + * + * This adopts the Signature/Algorithm/Arch pattern from: + * experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp + * + * Structure: + * - Signature: WHAT operation (types, layouts, direction, element ops) + * - Algorithm: HOW it's computed (tiles, warps, pipeline, scheduler, padding) + * - Arch: Target GPU architecture + */ + +#pragma once + +// Use common kernel_key types for DataType, Pipeline, etc. +#include "ck_tile/dispatcher/kernel_key.hpp" + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +// DataType, Pipeline, Scheduler, Epilogue are defined in kernel_key.hpp +// No need to redefine them here + +// ============================================================================= +// Data Type Enum (matching CK Tile numeric types) +// ============================================================================= + +enum class ConvDataType +{ + // Standard floating point + FP32, // float + FP64, // double + FP16, // half_t + BF16, // bf16_t + + // 8-bit float variants (FP8/BF8) + FP8, // fp8_t (E4M3) + BF8, // bf8_t (E5M2) + FP8_E4M3, // Explicit E4M3 format + FP8_E5M2, // Explicit E5M2 format + + // Integer types + INT8, // int8_t + UINT8, // uint8_t + INT32, // int32_t (accumulator) + + // 4-bit types (gfx950+ only) + FP4, // MXFP4 + INT4 // pk_int4_t +}; + +// ============================================================================= +// Direction and Layout Enums +// ============================================================================= + +enum class GroupedConvDirection +{ + FORWARD, + BACKWARD_DATA, + BACKWARD_WEIGHT +}; + +enum class ConvLayout2D +{ + GNHWC_GKYXC_GNHWK, // NHWC-style + NHWGC_GKYXC_NHWGK, + NGCHW_GKYXC_NGKHW, // NCHW-style + NGCHW_GKCYX_NGKHW +}; + +enum class ConvLayout3D +{ + GNDHWC_GKZYXC_GNDHWK, + NDHWGC_GKZYXC_NDHWGK, + NGCDHW_GKZYXC_NGKDHW, + NGCDHW_GKCZYX_NGKDHW +}; + +// ============================================================================= +// Element-wise Operations +// ============================================================================= + +enum class ElementwiseOp +{ + PASS_THROUGH, + BIAS, + BIAS_CLAMP, + SCALE, + BILINEAR, + RELU, + GELU, + SIGMOID, + TANH +}; + +// ============================================================================= +// Grouped Convolution Specialization +// ============================================================================= + +enum class ConvSpecialization +{ + DEFAULT, + FILTER_1X1_PAD0, + FILTER_1X1_STRIDE1_PAD0, + FILTER_3X3, + FILTER_5X5, + FILTER_7X7 +}; + +// ============================================================================= +// Memory Operation Types (for accumulator operations) +// ============================================================================= + +enum class MemoryOperation +{ + SET, // Direct write (=) + ATOMIC_ADD, // Atomic addition (+=) + ATOMIC_MAX, // Atomic max + ADD // Non-atomic addition +}; + +// ============================================================================= +// Epilogue Types +// ============================================================================= + +enum class EpilogueType +{ + CSHUFFLE, // C-shuffle epilogue + DEFAULT_2D, // Default 2D epilogue + DEFAULT_GEMM_2D, // Default GEMM 2D epilogue + DIRECT_STORE, // Direct store without shuffle + BIAS_ADD, // Add bias + BIAS_ADD_RELU, // Add bias + ReLU + BIAS_ADD_GELU // Add bias + GELU +}; + +// ============================================================================= +// Algorithm Enums (matching builder/types.hpp and CK Tile pipelines) +// ============================================================================= + +enum class PipelineVersion +{ + V1, // Basic pipeline V1 + V2, // Basic pipeline V2 + V3, // Compute V3 (intrawave only) + V4, // Compute V4 (double buffer, ping-pong LDS) + V5, // Compute V5 (wave groups) + V6, // Compute V6 (newest) + MEMORY, // Memory pipeline + COMPUTE_ASYNC, // Compute with async copy + PRESHUFFLE_V2 // Preshuffle V2 pipeline +}; + +enum class PipelineScheduler +{ + DEFAULT, + INTRAWAVE, + INTERWAVE +}; + +enum class GemmPadding +{ + DEFAULT, + NO_PADDING, // No padding + M_PADDING, + N_PADDING, + K_PADDING, + MN_PADDING, + MK_PADDING, + NK_PADDING, + MNK_PADDING +}; + +// ============================================================================= +// Signature Info (WHAT operation) +// ============================================================================= + +struct GroupedConvSignatureInfo +{ + int spatial_dim = 2; // 1, 2, or 3 + GroupedConvDirection direction = GroupedConvDirection::FORWARD; + std::string in_type = "fp16"; + std::string wei_type = "fp16"; + std::string out_type = "fp16"; + std::string acc_type = "fp32"; + std::string workspace_type = "fp32"; // For two-stage algorithms + std::string bias_type = "fp16"; // For bias epilogue + ElementwiseOp in_element_op = ElementwiseOp::PASS_THROUGH; + ElementwiseOp wei_element_op = ElementwiseOp::PASS_THROUGH; + ElementwiseOp out_element_op = ElementwiseOp::PASS_THROUGH; + ConvSpecialization conv_spec = ConvSpecialization::DEFAULT; + int num_groups = 1; + + // String helpers + static const char* direction_str(GroupedConvDirection dir) + { + switch(dir) + { + case GroupedConvDirection::FORWARD: return "fwd"; + case GroupedConvDirection::BACKWARD_DATA: return "bwd_data"; + case GroupedConvDirection::BACKWARD_WEIGHT: return "bwd_weight"; + default: return "unknown"; + } + } + + static const char* datatype_str(ConvDataType dt) + { + switch(dt) + { + case ConvDataType::FP32: return "fp32"; + case ConvDataType::FP64: return "fp64"; + case ConvDataType::FP16: return "fp16"; + case ConvDataType::BF16: return "bf16"; + case ConvDataType::FP8: return "fp8"; + case ConvDataType::BF8: return "bf8"; + case ConvDataType::FP8_E4M3: return "fp8_e4m3"; + case ConvDataType::FP8_E5M2: return "fp8_e5m2"; + case ConvDataType::INT8: return "int8"; + case ConvDataType::UINT8: return "uint8"; + case ConvDataType::INT32: return "int32"; + case ConvDataType::FP4: return "fp4"; + case ConvDataType::INT4: return "int4"; + default: return "unknown"; + } + } +}; + +// ============================================================================= +// Algorithm Info (HOW it's computed) +// ============================================================================= + +struct DataTileInfo +{ + int m = 128; // M tile (output spatial * N) + int n = 128; // N tile (K output channels) + int k = 64; // K tile (C input channels) +}; + +struct WarpGemmParams +{ + int gemm_m = 16; // MFMA M dimension (MPerXDL) + int gemm_n = 16; // MFMA N dimension (NPerXDL) + int m_iter = 2; // M iterations per warp (MXdlPerWave) + int n_iter = 2; // N iterations per warp (NXdlPerWave) +}; + +struct BlockWarpConfig +{ + int m_warp = 2; // Warps along M + int n_warp = 2; // Warps along N + int k_warp = 1; // Warps along K + int m_warp_tile = 32; // Warp tile M + int n_warp_tile = 32; // Warp tile N + int k_warp_tile = 16; // Warp tile K +}; + +struct VectorSizeInfo +{ + int a = 4; // Input vector size + int b = 8; // Weight vector size + int c = 8; // Output vector size +}; + +struct GroupedConvAlgorithmInfo +{ + DataTileInfo tile; + BlockWarpConfig warp; + VectorSizeInfo vector_size; + + PipelineVersion pipeline = PipelineVersion::V4; + PipelineScheduler scheduler = PipelineScheduler::INTRAWAVE; + GemmPadding padding = GemmPadding::MNK_PADDING; + MemoryOperation memory_op = MemoryOperation::SET; + EpilogueType epilogue = EpilogueType::CSHUFFLE; + + int thread_block_size = 256; + bool double_smem_buffer = false; + int num_wave_groups = 1; + int block_per_cu = 1; + int num_groups_to_merge = 1; + + // Pipeline string + static const char* pipeline_str(PipelineVersion pv) + { + switch(pv) + { + case PipelineVersion::V1: return "v1"; + case PipelineVersion::V2: return "v2"; + case PipelineVersion::V3: return "compv3"; + case PipelineVersion::V4: return "compv4"; + case PipelineVersion::V5: return "compv5"; + case PipelineVersion::V6: return "compv6"; + case PipelineVersion::MEMORY: return "mem"; + case PipelineVersion::COMPUTE_ASYNC: return "comp_async"; + case PipelineVersion::PRESHUFFLE_V2: return "preshuffle_v2"; + default: return "unknown"; + } + } + + static const char* scheduler_str(PipelineScheduler ps) + { + switch(ps) + { + case PipelineScheduler::DEFAULT: return "default"; + case PipelineScheduler::INTRAWAVE: return "intrawave"; + case PipelineScheduler::INTERWAVE: return "interwave"; + default: return "unknown"; + } + } + + static const char* memory_op_str(MemoryOperation mo) + { + switch(mo) + { + case MemoryOperation::SET: return "set"; + case MemoryOperation::ATOMIC_ADD: return "atomic_add"; + case MemoryOperation::ATOMIC_MAX: return "atomic_max"; + case MemoryOperation::ADD: return "add"; + default: return "unknown"; + } + } + + static const char* epilogue_str(EpilogueType et) + { + switch(et) + { + case EpilogueType::CSHUFFLE: return "cshuffle"; + case EpilogueType::DEFAULT_2D: return "default_2d"; + case EpilogueType::DEFAULT_GEMM_2D: return "default_gemm_2d"; + case EpilogueType::DIRECT_STORE: return "direct_store"; + case EpilogueType::BIAS_ADD: return "bias_add"; + case EpilogueType::BIAS_ADD_RELU: return "bias_add_relu"; + case EpilogueType::BIAS_ADD_GELU: return "bias_add_gelu"; + default: return "unknown"; + } + } +}; + +// ============================================================================= +// Arch Info (Target GPU) +// ============================================================================= + +struct ArchInfo +{ + std::string name = "gfx942"; // MI300X default + int max_waves_per_cu = 8; + int lds_size_kb = 64; + int sgpr_count = 108; + int vgpr_count = 512; + + bool supports_mfma_fp16() const { return name.find("gfx9") != std::string::npos; } + bool supports_wmma() const { return name.find("gfx11") != std::string::npos; } +}; + +// ============================================================================= +// Full Grouped Conv Config (combines Signature + Algorithm + Arch) +// ============================================================================= + +struct GroupedConvConfig +{ + GroupedConvSignatureInfo signature; + GroupedConvAlgorithmInfo algorithm; + ArchInfo arch; + + // Generate unique kernel name + std::string name() const + { + std::ostringstream oss; + oss << "grouped_conv_" << GroupedConvSignatureInfo::direction_str(signature.direction) + << "_" << signature.in_type << "_" << signature.spatial_dim << "d" << "_" + << GroupedConvAlgorithmInfo::pipeline_str(algorithm.pipeline) << "_" << algorithm.tile.m + << "x" << algorithm.tile.n << "x" << algorithm.tile.k; + return oss.str(); + } + + // Brief description + std::string brief() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " + << GroupedConvSignatureInfo::direction_str(signature.direction) + << " Grouped Convolution (" << signature.in_type << ")"; + return oss.str(); + } + + // Detailed description (tree-like) + std::string detailed() const + { + std::ostringstream oss; + oss << signature.spatial_dim << "D " + << GroupedConvSignatureInfo::direction_str(signature.direction) + << " Grouped Convolution Kernel\n"; + + oss << " Signature:\n"; + oss << " Data Type: " << signature.in_type << "\n"; + oss << " Accumulator: " << signature.acc_type << "\n"; + oss << " Groups: " << signature.num_groups << "\n"; + + oss << " Algorithm:\n"; + oss << " Thread Block Size: " << algorithm.thread_block_size << "\n"; + oss << " Data Tile: " << algorithm.tile.m << "x" << algorithm.tile.n << "x" + << algorithm.tile.k << "\n"; + oss << " Warp Config: " << algorithm.warp.m_warp << "x" << algorithm.warp.n_warp << "x" + << algorithm.warp.k_warp << "\n"; + oss << " Warp Tile: " << algorithm.warp.m_warp_tile << "x" << algorithm.warp.n_warp_tile + << "x" << algorithm.warp.k_warp_tile << "\n"; + oss << " Pipeline: " << GroupedConvAlgorithmInfo::pipeline_str(algorithm.pipeline) + << "\n"; + oss << " Scheduler: " << GroupedConvAlgorithmInfo::scheduler_str(algorithm.scheduler) + << "\n"; + + oss << " Arch:\n"; + oss << " Target: " << arch.name << "\n"; + + return oss.str(); + } +}; + +// ============================================================================= +// Predefined Configs +// ============================================================================= + +namespace configs { + +// Memory-bound config +template +struct Memory : public GroupedConvConfig +{ + Memory() + { + algorithm.tile = {128, 32, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {4, 1, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::MEMORY; + algorithm.double_smem_buffer = false; + } +}; + +// Compute V3 - Small +template +struct CompV3_Small : public GroupedConvConfig +{ + CompV3_Small() + { + algorithm.tile = {16, 64, 64}; + algorithm.warp = {1, 4, 1, 16, 16, 32}; + algorithm.pipeline = PipelineVersion::V3; + } +}; + +// Compute V3 - Medium +template +struct CompV3_Medium : public GroupedConvConfig +{ + CompV3_Medium() + { + algorithm.tile = {128, 128, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 16, 16, 32}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.block_per_cu = 2; + } +}; + +// Compute V3 - Large +template +struct CompV3_Large : public GroupedConvConfig +{ + CompV3_Large() + { + algorithm.tile = {256, 256, 128 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V3; + } +}; + +// Compute V4 - Double buffered +template +struct CompV4 : public GroupedConvConfig +{ + CompV4() + { + algorithm.tile = {256, 256, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {2, 2, 1, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V4; + algorithm.double_smem_buffer = true; + } +}; + +// Compute V5 - Wave groups +template +struct CompV5 : public GroupedConvConfig +{ + CompV5() + { + algorithm.tile = {128, 128, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {1, 1, 2, 32, 32, 16}; + algorithm.pipeline = PipelineVersion::V5; + algorithm.num_wave_groups = 2; + } +}; + +// WMMA config for gfx11xx +template +struct WMMA : public GroupedConvConfig +{ + WMMA() + { + algorithm.tile = {128, 128, 64 / (int)sizeof(PrecType)}; + algorithm.warp = {4, 2, 1, 16, 16, 16}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.block_per_cu = 2; + arch.name = "gfx1100"; + } +}; + +// Merged groups config +template +struct CompV3_MergedGroups : public GroupedConvConfig +{ + CompV3_MergedGroups() + { + algorithm.tile = {16, 32, 32}; + algorithm.warp = {1, 2, 1, 16, 16, 32}; + algorithm.vector_size = {4, 8, 8}; + algorithm.pipeline = PipelineVersion::V3; + algorithm.num_groups_to_merge = 2; + } +}; + +} // namespace configs + +// ============================================================================= +// DataType Traits (compile-time type info for CK Tile types) +// ============================================================================= + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; + static constexpr int size_bytes = 4; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; + static constexpr int size_bytes = 8; +}; + +// Forward declare CK Tile types for traits +// Note: actual ck_tile types are defined in ck_tile/core/numeric/ +// These traits allow working with type names at compile time + +// ============================================================================= +// ConvTypeConfig (input/weight/acc/output type combinations) +// ============================================================================= + +template +struct ConvTypeConfig +{ + using input_type = InDataType; + using weight_type = WeiDataType; + using output_type = OutDataType; + using accumulator_type = AccDataType; +}; + +// Common type configurations as type aliases +// FP16 -> FP32 accumulator -> FP16 output (most common) +// BF16 -> FP32 accumulator -> BF16 output +// FP8 -> FP32 accumulator -> FP8 output +// INT8 -> INT32 accumulator -> INT8 output + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp new file mode 100644 index 0000000000..8ddfe445ff --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp @@ -0,0 +1,537 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_kernel_decl.hpp + * @brief Declarative grouped convolution kernel specification + * + * USAGE: + * ====== + * + * // Named kernel sets for grouped convolution + * DECL_GROUPED_CONV_KERNEL_SET(gconv_fwd, + * .add("fp16", "nhwc", "forward", 128, 128, 32) + * .add("fp16", "nhwc", "forward", 256, 256, 64) + * ); + * + * // Access at runtime + * auto& set = GroupedConvKernelSetRegistry::instance().get("gconv_fwd"); + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { +namespace grouped_conv_decl { + +// ============================================================================= +// Wildcard constants +// ============================================================================= + +constexpr const char* ANY = "*"; +constexpr int ANY_INT = -1; + +// ============================================================================= +// GroupedConvSignature - WHAT operation +// ============================================================================= + +class GroupedConvSignature +{ + public: + std::string dtype_in_ = "fp16"; // Input data type + std::string dtype_wei_ = "fp16"; // Weight data type + std::string dtype_out_ = "fp16"; // Output data type + std::string dtype_acc_ = "fp32"; // Accumulator type + std::string dtype_workspace_ = "fp32"; // Workspace type (two-stage algorithms) + std::string dtype_bias_ = "fp16"; // Bias type (bias epilogue) + std::string layout_ = "nhwc"; // Data layout: nhwc, nchw + std::string conv_op_ = "forward"; // forward, bwd_data, bwd_weight + int num_dims_ = 2; // Spatial dimensions: 1, 2, or 3 + int groups_ = 1; // Group grouped convolution + std::string specialization_ = "default"; // Filter specialization + + GroupedConvSignature& dtype(const std::string& in, + const std::string& wei, + const std::string& out, + const std::string& acc = "fp32") + { + dtype_in_ = in; + dtype_wei_ = wei; + dtype_out_ = out; + dtype_acc_ = acc; + return *this; + } + + GroupedConvSignature& dtype(const std::string& all) + { + dtype_in_ = dtype_wei_ = dtype_out_ = dtype_bias_ = all; + dtype_acc_ = dtype_workspace_ = "fp32"; + return *this; + } + + GroupedConvSignature& dtype_workspace(const std::string& ws) + { + dtype_workspace_ = ws; + return *this; + } + + GroupedConvSignature& dtype_bias(const std::string& b) + { + dtype_bias_ = b; + return *this; + } + + GroupedConvSignature& layout(const std::string& l) + { + layout_ = l; + return *this; + } + GroupedConvSignature& conv_type(const std::string& op) + { + conv_op_ = op; + return *this; + } + GroupedConvSignature& dims(int d) + { + num_dims_ = d; + return *this; + } + GroupedConvSignature& groups(int g) + { + groups_ = g; + return *this; + } + GroupedConvSignature& spec(const std::string& s) + { + specialization_ = s; + return *this; + } + + std::string op_str() const + { + if(conv_op_ == "forward") + return "fwd"; + if(conv_op_ == "bwd_data") + return "bwd_data"; + if(conv_op_ == "bwd_weight") + return "bwd_weight"; + return conv_op_; + } +}; + +// ============================================================================= +// GroupedConvAlgorithm - HOW it's implemented +// ============================================================================= + +class GroupedConvAlgorithm +{ + public: + // Tile shape (M, N, K per tile - M=spatial*N, N=K_out, K=C_in) + int tile_m_ = 1; // Tile M (output spatial * batch) + int tile_n_ = 128; // Tile N (output channels K) + int tile_k_ = 128; // Tile K (input channels C) + + // Output spatial tile + int tile_ho_ = 1; + int tile_wo_ = 16; + + // Wave/warp shape + int wave_m_ = ANY_INT; + int wave_n_ = ANY_INT; + int wave_k_ = 1; + int warp_m_ = ANY_INT; + int warp_n_ = ANY_INT; + int warp_k_ = 16; + + // Vector sizes + int vector_a_ = 4; // Input vector size + int vector_b_ = 8; // Weight vector size + int vector_c_ = 8; // Output vector size + + // Pipeline configuration + std::string pipeline_ = "compv4"; + std::string scheduler_ = "intrawave"; + std::string epilogue_ = "cshuffle"; + std::string memory_op_ = "set"; // Memory operation: set, atomic_add, atomic_max, add + + // Occupancy/performance hints + int block_size_ = 256; + int block_per_cu_ = 1; + int num_wave_groups_ = 1; + int num_groups_to_merge_ = 1; + bool double_smem_buffer_ = false; + + // Padding -- always enabled for convolution (MNK padding assumed) + static constexpr bool pad_m_ = true; + static constexpr bool pad_n_ = true; + static constexpr bool pad_k_ = true; + + // Tile setter (M, N, K) + GroupedConvAlgorithm& tile(int m, int n, int k) + { + tile_m_ = m; + tile_n_ = n; + tile_k_ = k; + return *this; + } + + GroupedConvAlgorithm& tile_output(int ho, int wo) + { + tile_ho_ = ho; + tile_wo_ = wo; + return *this; + } + + GroupedConvAlgorithm& wave(int m, int n, int k = 1) + { + wave_m_ = m; + wave_n_ = n; + wave_k_ = k; + return *this; + } + + GroupedConvAlgorithm& warp(int m, int n, int k = 16) + { + warp_m_ = m; + warp_n_ = n; + warp_k_ = k; + return *this; + } + + GroupedConvAlgorithm& vector_sizes(int a, int b, int c) + { + vector_a_ = a; + vector_b_ = b; + vector_c_ = c; + return *this; + } + + GroupedConvAlgorithm& pipeline(const std::string& p) + { + pipeline_ = p; + return *this; + } + GroupedConvAlgorithm& scheduler(const std::string& s) + { + scheduler_ = s; + return *this; + } + GroupedConvAlgorithm& epilogue(const std::string& e) + { + epilogue_ = e; + return *this; + } + GroupedConvAlgorithm& memory_op(const std::string& m) + { + memory_op_ = m; + return *this; + } + + // Occupancy setters + GroupedConvAlgorithm& block_per_cu(int b) + { + block_per_cu_ = b; + return *this; + } + GroupedConvAlgorithm& num_wave_groups(int n) + { + num_wave_groups_ = n; + return *this; + } + GroupedConvAlgorithm& num_groups_to_merge(int n) + { + num_groups_to_merge_ = n; + return *this; + } + GroupedConvAlgorithm& double_smem_buffer(bool d) + { + double_smem_buffer_ = d; + return *this; + } + + bool needs_expansion() const + { + return wave_m_ == ANY_INT || warp_m_ == ANY_INT || pipeline_ == "*" || scheduler_ == "*"; + } + + /// Check if specific parameter needs expansion + bool needs_wave_expansion() const { return wave_m_ == ANY_INT || wave_n_ == ANY_INT; } + bool needs_warp_expansion() const { return warp_m_ == ANY_INT || warp_n_ == ANY_INT; } + bool needs_pipeline_expansion() const { return pipeline_ == "*"; } + bool needs_scheduler_expansion() const { return scheduler_ == "*"; } + + /// Auto-fill with defaults (for single kernel generation) + void auto_fill() + { + if(wave_m_ == ANY_INT) + wave_m_ = 2; + if(wave_n_ == ANY_INT) + wave_n_ = 2; + if(warp_m_ == ANY_INT) + warp_m_ = 32; + if(warp_n_ == ANY_INT) + warp_n_ = 32; + if(pipeline_ == "*") + pipeline_ = "compv4"; + if(scheduler_ == "*") + scheduler_ = "intrawave"; + } + + /// Get all valid wave configurations for arch + static std::vector> valid_wave_configs(const std::string& arch) + { + // Match arch_specs_generated.py WARP_SUPPORTED_COMBINATIONS + if(arch == "gfx942" || arch == "gfx90a" || arch == "gfx950") + { + return {{1, 4, 1}, {2, 2, 1}, {4, 1, 1}}; + } + return {{2, 2, 1}}; // Default + } + + /// Get all valid warp tile configurations + static std::vector> valid_warp_configs(const std::string& arch, + const std::string& dtype) + { + // Match arch_specs_generated.py WARP_TILE_SUPPORTED_COMBINATIONS + if(arch == "gfx942" && (dtype == "fp16" || dtype == "bf16")) + { + return {{16, 16, 16}, {32, 32, 16}}; + } + return {{32, 32, 16}}; // Default + } + + /// Get all valid pipeline/scheduler combinations for forward conv. + /// Backward operations (bwd_data/bwd_weight) only support compv3 and mem + /// due to transpose_tile2d and get_length constraints in CK Tile. + static std::vector> valid_trait_configs() + { + return { + {"compv3", "intrawave"}, + {"compv4", "intrawave"}, + {"compv5", "intrawave"}, + {"mem", "intrawave"}, + {"mem", "interwave"}, + }; + } +}; + +// ============================================================================= +// GroupedConvKernelDecl +// ============================================================================= + +struct GroupedConvKernelDecl +{ + GroupedConvSignature signature; + GroupedConvAlgorithm algorithm; + std::string arch = "gfx942"; + + GroupedConvKernelDecl() = default; + + GroupedConvKernelDecl(const GroupedConvSignature& sig, + const GroupedConvAlgorithm& algo, + const std::string& a = "gfx942") + : signature(sig), algorithm(algo), arch(a) + { + } + + std::string name() const + { + std::ostringstream oss; + // Generate full kernel name similar to GEMM: + // grouped_conv____d______ + oss << "grouped_conv_" << signature.op_str() << "_" << signature.dtype_in_ << "_" + << signature.layout_ << "_" << signature.num_dims_ << "d" << "_" << algorithm.pipeline_ + << "_" << algorithm.epilogue_ << "_" << algorithm.scheduler_ << "_" << algorithm.tile_m_ + << "x" << algorithm.tile_n_ << "x" << algorithm.tile_k_ << "_" << algorithm.wave_m_ + << "x" << algorithm.wave_n_ << "x" << algorithm.wave_k_ << "_" << algorithm.warp_m_ + << "x" << algorithm.warp_n_ << "x" << algorithm.warp_k_; + return oss.str(); + } + + bool has_wildcards() const { return algorithm.needs_expansion() || arch == "*"; } +}; + +// ============================================================================= +// GroupedConvKernelSet +// ============================================================================= + +class GroupedConvKernelSet +{ + public: + GroupedConvKernelSet() = default; + + GroupedConvKernelSet& add(const GroupedConvSignature& sig, + const GroupedConvAlgorithm& algo, + const std::string& arch = "gfx942") + { + decls_.emplace_back(sig, algo, arch); + return *this; + } + + // Simple add: dtype, layout, conv_type, tile_k, tile_c + GroupedConvKernelSet& add(const std::string& dtype, + const std::string& layout, + const std::string& conv_type, + int tile_k, + int tile_c, + const std::string& arch = "gfx942") + { + GroupedConvSignature sig; + sig.dtype(dtype).layout(layout).conv_type(conv_type); + GroupedConvAlgorithm algo; + algo.tile(1, tile_k, tile_c); + decls_.emplace_back(sig, algo, arch); + return *this; + } + + GroupedConvKernelSet& merge(const GroupedConvKernelSet& other) + { + decls_.insert(decls_.end(), other.decls_.begin(), other.decls_.end()); + return *this; + } + + const std::vector& declarations() const { return decls_; } + size_t size() const { return decls_.size(); } + + void print(std::ostream& os = std::cout) const + { + os << "GroupedConvKernelSet (" << size() << " declarations):\n"; + for(const auto& d : decls_) + { + os << " - " << d.name(); + if(d.algorithm.needs_expansion()) + os << " [expands]"; + os << "\n"; + } + } + + GroupedConvKernelSet& tag(const std::string& t) + { + tag_ = t; + return *this; + } + std::string tag() const { return tag_; } + + private: + std::vector decls_; + std::string tag_; +}; + +// ============================================================================= +// GroupedConvKernelSetRegistry +// ============================================================================= + +class GroupedConvKernelSetRegistry +{ + public: + static GroupedConvKernelSetRegistry& instance() + { + static GroupedConvKernelSetRegistry reg; + return reg; + } + + void add(const std::string& name, const GroupedConvKernelSet& set) + { + sets_[name] = set; + if(std::find(order_.begin(), order_.end(), name) == order_.end()) + { + order_.push_back(name); + } + } + + // Alias for add() for consistency with GEMM API + void register_set(const std::string& name, const GroupedConvKernelSet& set) { add(name, set); } + + const GroupedConvKernelSet& get(const std::string& name) const + { + static GroupedConvKernelSet empty; + auto it = sets_.find(name); + return it != sets_.end() ? it->second : empty; + } + + bool has(const std::string& name) const { return sets_.find(name) != sets_.end(); } + + std::vector names() const { return order_; } + size_t size() const { return sets_.size(); } + + void clear() + { + sets_.clear(); + order_.clear(); + } + + void print() const + { + std::cout << "Grouped Conv Kernel Sets (" << size() << "):\n"; + for(const auto& name : order_) + { + const auto& set = sets_.at(name); + std::cout << " " << name << ": " << set.size() << " declarations\n"; + } + } + + private: + GroupedConvKernelSetRegistry() = default; + std::unordered_map sets_; + std::vector order_; +}; + +// ============================================================================= +// Static Registrar +// ============================================================================= + +struct GroupedConvKernelSetRegistrar +{ + GroupedConvKernelSetRegistrar(const std::string& name, const GroupedConvKernelSet& set) + { + GroupedConvKernelSetRegistry::instance().add(name, set); + } +}; + +} // namespace grouped_conv_decl + +// Convenience aliases +using GroupedConvSignature = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgorithm = grouped_conv_decl::GroupedConvAlgorithm; +using GroupedConvKernelDecl = grouped_conv_decl::GroupedConvKernelDecl; +using GroupedConvKernelSet = grouped_conv_decl::GroupedConvKernelSet; +using GroupedConvKernelSetRegistry = grouped_conv_decl::GroupedConvKernelSetRegistry; + +} // namespace dispatcher +} // namespace ck_tile + +// ============================================================================= +// Declaration Macros +// ============================================================================= + +#define CK_GROUPED_CONV_DECL_CAT_(a, b) CK_GROUPED_CONV_DECL_CAT_IMPL_(a, b) +#define CK_GROUPED_CONV_DECL_CAT_IMPL_(a, b) a##b + +// Note: __extension__ suppresses warnings about __COUNTER__ being a GCC/Clang extension +#define DECL_GROUPED_CONV_KERNEL_SET(name, ...) \ + __extension__ static ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSetRegistrar \ + CK_GROUPED_CONV_DECL_CAT_(_gconv_kset_reg_, __COUNTER__)( \ + #name, \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet() __VA_ARGS__.tag(#name)) + +#define DECL_GROUPED_CONV_KERNEL_ALL(dtype, layout) \ + __extension__ static ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSetRegistrar \ + CK_GROUPED_CONV_DECL_CAT_(_gconv_kset_reg_, __COUNTER__)( \ + #dtype "_" #layout "_all", \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet().add( \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvSignature().dtype(#dtype).layout( \ + #layout), \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvAlgorithm(), \ + "*")) + +#define GROUPED_CONV_KERNEL_SET(name) \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet name +#define BEGIN_GROUPED_CONV_KERNEL_SET() \ + ::ck_tile::dispatcher::grouped_conv_decl::GroupedConvKernelSet() diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp new file mode 100644 index 0000000000..5b58f37206 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp @@ -0,0 +1,255 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_problem.hpp + * @brief Grouped Convolution problem definition + */ + +#pragma once + +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +/** + * @brief Grouped Convolution operation type + */ +enum class GroupedConvOp +{ + Forward, // Y = Conv(X, W) + BackwardData, // dX = ConvBwdData(dY, W) + BackwardWeight // dW = ConvBwdWeight(X, dY) +}; + +/** + * @brief Grouped Convolution problem specification + */ +struct GroupedConvProblem +{ + // Batch and channels + std::int64_t N; // Batch size + std::int64_t C; // Input channels + std::int64_t K; // Output channels (filters) + std::int64_t G; // Number of groups (1 for standard conv) + + // Spatial dimensions (supports 1D, 2D, 3D) + std::array input_spatial; // {D, H, W} or {1, H, W} for 2D + std::array filter_spatial; // {Z, Y, X} or {1, Y, X} for 2D + std::array output_spatial; // {Do, Ho, Wo} or {1, Ho, Wo} for 2D + + // Convolution parameters + std::array stride; // Stride in each dimension + std::array padding; // Padding in each dimension + std::array dilation; // Dilation in each dimension + + // Operation type + GroupedConvOp op = GroupedConvOp::Forward; + + // Split-K for backward weight (k_batch parameter in CK Tile). + // Values > 1 split the reduction dimension across multiple thread blocks + // and use atomic accumulation. + int split_k = 1; + + // Default constructor for 2D convolution + GroupedConvProblem() + : N(1), + C(64), + K(64), + G(1), + input_spatial{1, 28, 28}, + filter_spatial{1, 3, 3}, + output_spatial{1, 26, 26}, + stride{1, 1, 1}, + padding{0, 0, 0}, + dilation{1, 1, 1}, + op(GroupedConvOp::Forward) + { + } + + // Constructor for 2D convolution + GroupedConvProblem(std::int64_t n, + std::int64_t c, + std::int64_t k, + std::int64_t hi, + std::int64_t wi, + std::int64_t y, + std::int64_t x, + std::int64_t stride_h = 1, + std::int64_t stride_w = 1, + std::int64_t pad_h = 0, + std::int64_t pad_w = 0, + std::int64_t dilation_h = 1, + std::int64_t dilation_w = 1) + : N(n), + C(c), + K(k), + G(1), + input_spatial{1, hi, wi}, + filter_spatial{1, y, x}, + stride{1, stride_h, stride_w}, + padding{0, pad_h, pad_w}, + dilation{1, dilation_h, dilation_w}, + op(GroupedConvOp::Forward) + { + compute_output_size(); + } + + /// Check if problem dimensions are valid + bool is_valid() const + { + return N > 0 && C > 0 && K > 0 && G > 0 && (C % G == 0) && (K % G == 0); + } + + /// Compute output spatial dimensions + void compute_output_size() + { + for(int i = 0; i < 3; ++i) + { + std::int64_t effective_filter = (filter_spatial[i] - 1) * dilation[i] + 1; + output_spatial[i] = + (input_spatial[i] + 2 * padding[i] - effective_filter) / stride[i] + 1; + } + } + + /// Get 2D height/width accessors + std::int64_t Hi() const { return input_spatial[1]; } + std::int64_t Wi() const { return input_spatial[2]; } + std::int64_t Ho() const { return output_spatial[1]; } + std::int64_t Wo() const { return output_spatial[2]; } + std::int64_t Y() const { return filter_spatial[1]; } // Filter height + std::int64_t X() const { return filter_spatial[2]; } // Filter width + + /// Get total FLOPs for this convolution + double get_flops() const + { + // Forward: 2 * N * K * Ho * Wo * C * Y * X / G + double spatial_out = 1.0; + double filter_size = 1.0; + for(int i = 0; i < 3; ++i) + { + spatial_out *= output_spatial[i]; + filter_size *= filter_spatial[i]; + } + return 2.0 * N * K * spatial_out * (C / G) * filter_size; + } + + /// Check if this is a depthwise convolution + bool is_depthwise() const { return G == C && G == K; } + + /// Check if this is a pointwise (1x1) convolution + bool is_pointwise() const + { + return filter_spatial[0] == 1 && filter_spatial[1] == 1 && filter_spatial[2] == 1; + } + + /// String representation + std::string to_string() const + { + std::string s = "GroupedConvProblem(N=" + std::to_string(N); + s += ", C=" + std::to_string(C) + ", K=" + std::to_string(K); + s += ", G=" + std::to_string(G); + s += ", Hi=" + std::to_string(Hi()) + ", Wi=" + std::to_string(Wi()); + s += ", Y=" + std::to_string(Y()) + ", X=" + std::to_string(X()); + s += ", Ho=" + std::to_string(Ho()) + ", Wo=" + std::to_string(Wo()); + s += ")"; + return s; + } +}; + +// ============================================================================= +// GroupedConvProblemBuilder +// ============================================================================= + +/// Builder pattern for Grouped Convolution problem configuration +class GroupedConvProblemBuilder +{ + public: + GroupedConvProblemBuilder() = default; + + GroupedConvProblemBuilder& batch(std::int64_t n) + { + problem_.N = n; + return *this; + } + + GroupedConvProblemBuilder& channels(std::int64_t c, std::int64_t k) + { + problem_.C = c; + problem_.K = k; + return *this; + } + + GroupedConvProblemBuilder& groups(std::int64_t g) + { + problem_.G = g; + return *this; + } + + GroupedConvProblemBuilder& input_size(std::int64_t h, std::int64_t w) + { + problem_.input_spatial[0] = 1; + problem_.input_spatial[1] = h; + problem_.input_spatial[2] = w; + return *this; + } + + GroupedConvProblemBuilder& filter_size(std::int64_t y, std::int64_t x) + { + problem_.filter_spatial[0] = 1; + problem_.filter_spatial[1] = y; + problem_.filter_spatial[2] = x; + return *this; + } + + GroupedConvProblemBuilder& stride(std::int64_t sh, std::int64_t sw) + { + problem_.stride[0] = 1; + problem_.stride[1] = sh; + problem_.stride[2] = sw; + return *this; + } + + GroupedConvProblemBuilder& padding(std::int64_t ph, std::int64_t pw) + { + problem_.padding[0] = 0; + problem_.padding[1] = ph; + problem_.padding[2] = pw; + return *this; + } + + GroupedConvProblemBuilder& dilation(std::int64_t dh, std::int64_t dw) + { + problem_.dilation[0] = 1; + problem_.dilation[1] = dh; + problem_.dilation[2] = dw; + return *this; + } + + GroupedConvProblemBuilder& operation(GroupedConvOp op) + { + problem_.op = op; + return *this; + } + + [[nodiscard]] GroupedConvProblem build() const + { + GroupedConvProblem p = problem_; + p.compute_output_size(); + if(!p.is_valid()) + { + throw std::invalid_argument("Invalid grouped convolution problem dimensions"); + } + return p; + } + + private: + GroupedConvProblem problem_; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp new file mode 100644 index 0000000000..42698a0bc8 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp @@ -0,0 +1,614 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_registry.hpp + * @brief Grouped Convolution kernel registry and dispatcher + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" + +namespace ck_tile { +namespace dispatcher { + +// ============================================================================= +// Thread-local buffer context for GroupedConvDispatcher::run() +// The generated conv backend RunFn reads these to get buffer pointers. +// ============================================================================= + +struct ConvDispatchBuffers +{ + const void* input_ptr = nullptr; + const void* weight_ptr = nullptr; + void* output_ptr = nullptr; + int warmup = 3; + int repeat = 10; + bool benchmarking = true; + int split_k = 1; +}; + +inline thread_local ConvDispatchBuffers g_conv_dispatch_buffers; + +// ============================================================================= +// GroupedConvKernelKey - Unique identifier for a grouped convolution kernel +// ============================================================================= + +struct GroupedConvKernelKey +{ + // Signature fields + std::string dtype_in; + std::string dtype_wei; + std::string dtype_out; + std::string layout; // e.g., "nhwgc" + int ndim_spatial = 2; // 1, 2, or 3 + GroupedConvOp op = GroupedConvOp::Forward; + + // Tile configuration + int tile_m = 1; + int tile_n = 128; + int tile_k = 128; + + // Wave/warp configuration + int wave_m = 2; + int wave_n = 2; + int wave_k = 1; + int warp_m = 32; + int warp_n = 32; + int warp_k = 16; + + // Pipeline + std::string pipeline = "compv3"; + std::string scheduler = "intrawave"; + std::string epilogue = "cshuffle"; + + // ConvConfigBase parity fields + int vector_size_a = 4; + int vector_size_b = 8; + int vector_size_c = 8; + int block_per_cu = 1; + int num_wave_groups = 1; + int num_groups_to_merge = 1; + + // GPU architecture (for filter_by_arch) + std::string arch = "gfx942"; + + bool operator==(const GroupedConvKernelKey& other) const + { + return dtype_in == other.dtype_in && dtype_wei == other.dtype_wei && + dtype_out == other.dtype_out && layout == other.layout && + ndim_spatial == other.ndim_spatial && op == other.op && tile_m == other.tile_m && + tile_n == other.tile_n && tile_k == other.tile_k && wave_m == other.wave_m && + wave_n == other.wave_n && wave_k == other.wave_k && warp_m == other.warp_m && + warp_n == other.warp_n && warp_k == other.warp_k && pipeline == other.pipeline && + scheduler == other.scheduler && epilogue == other.epilogue && + vector_size_a == other.vector_size_a && vector_size_b == other.vector_size_b && + vector_size_c == other.vector_size_c && block_per_cu == other.block_per_cu && + num_wave_groups == other.num_wave_groups && + num_groups_to_merge == other.num_groups_to_merge && arch == other.arch; + } + + std::string to_string() const + { + std::string op_str; + switch(op) + { + case GroupedConvOp::Forward: op_str = "fwd"; break; + case GroupedConvOp::BackwardData: op_str = "bwd_data"; break; + case GroupedConvOp::BackwardWeight: op_str = "bwd_weight"; break; + } + return "grouped_conv_" + op_str + "_" + dtype_in + "_" + std::to_string(ndim_spatial) + + "d_" + std::to_string(tile_m) + "x" + std::to_string(tile_n) + "x" + + std::to_string(tile_k) + "_" + std::to_string(wave_m) + "x" + + std::to_string(wave_n) + "x" + std::to_string(wave_k) + "_" + + std::to_string(warp_m) + "x" + std::to_string(warp_n) + "x" + + std::to_string(warp_k) + "_" + pipeline; + } +}; + +struct GroupedConvKernelKeyHash +{ + std::size_t operator()(const GroupedConvKernelKey& key) const + { + std::size_t h = std::hash{}(key.dtype_in); + h ^= std::hash{}(key.layout) << 1; + h ^= std::hash{}(key.ndim_spatial) << 2; + h ^= std::hash{}(static_cast(key.op)) << 3; + h ^= std::hash{}(key.tile_m) << 4; + h ^= std::hash{}(key.tile_n) << 5; + h ^= std::hash{}(key.tile_k) << 6; + h ^= std::hash{}(key.wave_m) << 7; + h ^= std::hash{}(key.wave_n) << 8; + h ^= std::hash{}(key.warp_m) << 9; + h ^= std::hash{}(key.warp_n) << 10; + h ^= std::hash{}(key.pipeline) << 11; + h ^= std::hash{}(key.arch) << 12; + return h; + } +}; + +// ============================================================================= +// GroupedConvKernelInstance - Runtime representation of a kernel +// ============================================================================= + +// Forward declaration for shared_ptr type alias +class GroupedConvKernelInstance; +using GroupedConvKernelInstancePtr = std::shared_ptr; + +class GroupedConvKernelInstance +{ + public: + using RunFn = std::function; + + GroupedConvKernelInstance(const GroupedConvKernelKey& key, + const std::string& name, + RunFn run_fn) + : key_(key), name_(name), run_fn_(std::move(run_fn)) + { + } + + const GroupedConvKernelKey& key() const { return key_; } + const std::string& name() const { return name_; } + + float run(const GroupedConvProblem& problem, void* stream = nullptr) const + { + return run_fn_(problem, stream); + } + + bool matches(const GroupedConvProblem& problem) const + { + // Check if this kernel can handle the problem + return problem.op == key_.op; + } + + private: + GroupedConvKernelKey key_; + std::string name_; + RunFn run_fn_; +}; + +// ============================================================================= +// GroupedConvRegistry - Stores and manages grouped convolution kernels +// ============================================================================= + +class GroupedConvRegistry : public BaseRegistry +{ + using Base = BaseRegistry; + + public: + GroupedConvRegistry() = default; + + /// Singleton instance for global kernel registration + static GroupedConvRegistry& instance() + { + static GroupedConvRegistry registry; + return registry; + } + + /// Register kernels from a GroupedConvKernelSet (atomic batch registration) + bool register_set(const GroupedConvKernelSet& kernel_set, Priority priority = Priority::Normal) + { + // Build all instances first, then register under a single lock hold + // so readers never see a half-registered set. + std::vector>> + batch; + batch.reserve(kernel_set.declarations().size()); + + for(const auto& decl : kernel_set.declarations()) + { + GroupedConvKernelKey key; + key.dtype_in = decl.signature.dtype_in_; + key.dtype_wei = decl.signature.dtype_wei_; + key.dtype_out = decl.signature.dtype_out_; + key.layout = decl.signature.layout_; + key.ndim_spatial = decl.signature.num_dims_; + key.op = (decl.signature.conv_op_ == "forward") ? GroupedConvOp::Forward + : (decl.signature.conv_op_ == "bwd_data") ? GroupedConvOp::BackwardData + : GroupedConvOp::BackwardWeight; + key.tile_m = decl.algorithm.tile_m_; + key.tile_n = decl.algorithm.tile_n_; + key.tile_k = decl.algorithm.tile_k_; + key.wave_m = decl.algorithm.wave_m_; + key.wave_n = decl.algorithm.wave_n_; + key.wave_k = decl.algorithm.wave_k_; + key.warp_m = decl.algorithm.warp_m_; + key.warp_n = decl.algorithm.warp_n_; + key.warp_k = decl.algorithm.warp_k_; + key.pipeline = decl.algorithm.pipeline_; + key.scheduler = decl.algorithm.scheduler_; + key.epilogue = decl.algorithm.epilogue_; + key.vector_size_a = decl.algorithm.vector_a_; + key.vector_size_b = decl.algorithm.vector_b_; + key.vector_size_c = decl.algorithm.vector_c_; + key.block_per_cu = decl.algorithm.block_per_cu_; + key.num_wave_groups = decl.algorithm.num_wave_groups_; + key.num_groups_to_merge = decl.algorithm.num_groups_to_merge_; + key.arch = decl.arch; + + batch.emplace_back(key, + std::make_shared( + key, decl.name(), [](const GroupedConvProblem&, void*) -> float { + return 0.0f; + })); + } + + std::lock_guard lock(mutex()); + bool any_registered = false; + for(auto& [key, instance] : batch) + { + auto it = entries().find(key); + if(it == entries().end() || it->second.priority <= priority) + { + entries_mut()[key] = typename Base::Entry{std::move(instance), priority}; + any_registered = true; + } + } + return any_registered; + } + + /// Find the best kernel for a problem + const GroupedConvKernelInstance* find(const GroupedConvProblem& problem) const + { + std::lock_guard lock(mutex()); + const GroupedConvKernelInstance* best = nullptr; + Priority best_priority = Priority::Low; + + for(const auto& [key, entry] : entries()) + { + if(entry.instance->matches(problem)) + { + if(!best || entry.priority > best_priority) + { + best = entry.instance.get(); + best_priority = entry.priority; + } + } + } + + return best; + } + + /// Get all registered kernels + std::vector all_kernels() const + { + std::lock_guard lock(mutex()); + std::vector result; + for(const auto& [key, entry] : entries()) + { + result.push_back(entry.instance.get()); + } + return result; + } + + /// Export registry to JSON string + std::string export_json(bool include_statistics = false) const + { + // Note: get_name() acquires the mutex internally, so we must NOT hold + // the registry mutex here (std::mutex is not recursive). + std::string reg_name = get_name(); + + std::lock_guard lock(mutex()); + std::ostringstream json; + + json << "{\n"; + json << " \"metadata\": {\n"; + json << " \"registry_name\": \"" << json_escape(reg_name) << "\",\n"; + json << " \"total_kernels\": " << entries().size() << "\n"; + json << " }"; + + if(include_statistics && !entries().empty()) + { + std::map by_datatype; + std::map by_pipeline; + std::map by_arch; + + for(const auto& [key, entry] : entries()) + { + std::string dtype_key = key.dtype_in + "_" + key.dtype_wei + "_" + key.dtype_out; + by_datatype[dtype_key]++; + by_pipeline[key.pipeline]++; + by_arch[key.arch]++; + } + + json << ",\n \"statistics\": {\n"; + json << " \"by_datatype\": {"; + bool first = true; + for(const auto& [dtype, count] : by_datatype) + { + if(!first) + json << ","; + json << "\"" << json_escape(dtype) << "\":" << count; + first = false; + } + json << "},\n"; + json << " \"by_pipeline\": {"; + first = true; + for(const auto& [pipeline, count] : by_pipeline) + { + if(!first) + json << ","; + json << "\"" << json_escape(pipeline) << "\":" << count; + first = false; + } + json << "},\n"; + json << " \"by_arch\": {"; + first = true; + for(const auto& [arch, count] : by_arch) + { + if(!first) + json << ","; + json << "\"" << json_escape(arch) << "\":" << count; + first = false; + } + json << "}\n }"; + } + + json << ",\n \"kernels\": [\n"; + bool first = true; + for(const auto& [key, entry] : entries()) + { + if(!first) + json << ",\n"; + json << " " << export_kernel_json(*entry.instance); + first = false; + } + json << "\n ]\n"; + json << "}\n"; + + return json.str(); + } + + /// Export registry to JSON file + void export_json_to_file(const std::string& filename, bool include_statistics = false) const + { + std::string json_str = export_json(include_statistics); + std::ofstream file(filename); + if(!file.is_open()) + { + throw std::runtime_error("Failed to open file for export: " + filename); + } + file << json_str; + } + + /// Get kernels matching a predicate + std::vector + filter(std::function predicate) const + { + std::lock_guard lock(mutex()); + std::vector result; + for(const auto& [key, entry] : entries()) + { + if(predicate(*entry.instance)) + { + result.push_back(entry.instance.get()); + } + } + return result; + } + + /// Remove kernels not matching the arch + std::size_t filter_by_arch(const std::string& gpu_arch) + { + std::lock_guard lock(mutex()); + std::vector to_remove; + for(const auto& [key, entry] : entries()) + { + if(key.arch != gpu_arch) + { + to_remove.push_back(key); + } + } + for(const auto& key : to_remove) + { + entries_mut().erase(key); + } + return to_remove.size(); + } + + private: + static std::string json_escape(const std::string& str) + { + std::ostringstream oss; + for(char c : str) + { + switch(c) + { + case '"': oss << "\\\""; break; + case '\\': oss << "\\\\"; break; + case '\b': oss << "\\b"; break; + case '\f': oss << "\\f"; break; + case '\n': oss << "\\n"; break; + case '\r': oss << "\\r"; break; + case '\t': oss << "\\t"; break; + default: + if(c < 0x20) + { + oss << "\\u" << std::hex << std::setw(4) << std::setfill('0') << (int)c; + } + else + { + oss << c; + } + } + } + return oss.str(); + } + + static std::string export_kernel_json(const GroupedConvKernelInstance& kernel) + { + std::ostringstream json; + const auto& key = kernel.key(); + + std::string op_str; + switch(key.op) + { + case GroupedConvOp::Forward: op_str = "fwd"; break; + case GroupedConvOp::BackwardData: op_str = "bwd_data"; break; + case GroupedConvOp::BackwardWeight: op_str = "bwd_weight"; break; + } + + json << "{\n"; + json << " \"name\": \"" << json_escape(kernel.name()) << "\",\n"; + json << " \"signature\": {\n"; + json << " \"dtype_in\": \"" << json_escape(key.dtype_in) << "\",\n"; + json << " \"dtype_wei\": \"" << json_escape(key.dtype_wei) << "\",\n"; + json << " \"dtype_out\": \"" << json_escape(key.dtype_out) << "\",\n"; + json << " \"layout\": \"" << json_escape(key.layout) << "\",\n"; + json << " \"ndim_spatial\": " << key.ndim_spatial << ",\n"; + json << " \"op\": \"" << op_str << "\"\n"; + json << " },\n"; + json << " \"algorithm\": {\n"; + json << " \"tile_m\": " << key.tile_m << ",\n"; + json << " \"tile_n\": " << key.tile_n << ",\n"; + json << " \"tile_k\": " << key.tile_k << ",\n"; + json << " \"wave\": \"" << key.wave_m << "x" << key.wave_n << "x" << key.wave_k + << "\",\n"; + json << " \"warp\": \"" << key.warp_m << "x" << key.warp_n << "x" << key.warp_k + << "\",\n"; + json << " \"pipeline\": \"" << json_escape(key.pipeline) << "\",\n"; + json << " \"scheduler\": \"" << json_escape(key.scheduler) << "\",\n"; + json << " \"epilogue\": \"" << json_escape(key.epilogue) << "\",\n"; + json << " \"vector_sizes\": [" << key.vector_size_a << "," << key.vector_size_b + << "," << key.vector_size_c << "],\n"; + json << " \"block_per_cu\": " << key.block_per_cu << ",\n"; + json << " \"num_wave_groups\": " << key.num_wave_groups << ",\n"; + json << " \"num_groups_to_merge\": " << key.num_groups_to_merge << "\n"; + json << " },\n"; + json << " \"arch\": \"" << json_escape(key.arch) << "\"\n"; + json << " }"; + + return json.str(); + } +}; + +// ============================================================================= +// GroupedConvDispatcher - Selects and runs the best kernel for a problem +// ============================================================================= + +class GroupedConvDispatcher +{ + public: + enum class SelectionStrategy + { + PriorityBased, + Heuristic + }; + + using HeuristicFunction = std::function(const GroupedConvProblem&)>; + + explicit GroupedConvDispatcher(GroupedConvRegistry* registry) + : registry_(registry), strategy_(SelectionStrategy::PriorityBased) + { + } + + void set_strategy(SelectionStrategy s) { strategy_ = s; } + void set_heuristic(HeuristicFunction fn) { heuristic_ = std::move(fn); } + + /// Select the best kernel for a problem (does not run it) + const GroupedConvKernelInstance* select_kernel(const GroupedConvProblem& problem) const + { + if(strategy_ == SelectionStrategy::Heuristic) + return select_heuristic(problem); + return registry_->find(problem); + } + + /// Run convolution with automatic kernel selection (legacy - no buffers) + float run(const GroupedConvProblem& problem, void* stream = nullptr) + { + const auto* kernel = select_kernel(problem); + if(!kernel) + { + throw NoKernelFound("No suitable grouped convolution kernel found for problem: " + + problem.to_string()); + } + return kernel->run(problem, stream); + } + + /// Run convolution with buffer pointers and automatic kernel selection. + /// Sets the thread-local buffer context before dispatching to the kernel. + float run(const void* input_ptr, + const void* weight_ptr, + void* output_ptr, + const GroupedConvProblem& problem, + void* stream = nullptr, + int warmup = 3, + int repeat = 10) + { + const auto* kernel = select_kernel(problem); + if(!kernel) + { + throw NoKernelFound("No suitable grouped convolution kernel found for problem: " + + problem.to_string()); + } + g_conv_dispatch_buffers.input_ptr = input_ptr; + g_conv_dispatch_buffers.weight_ptr = weight_ptr; + g_conv_dispatch_buffers.output_ptr = output_ptr; + g_conv_dispatch_buffers.warmup = warmup; + g_conv_dispatch_buffers.repeat = repeat; + g_conv_dispatch_buffers.benchmarking = benchmarking_; + g_conv_dispatch_buffers.split_k = problem.split_k; + return kernel->run(problem, stream); + } + + /// Enable or disable GPU benchmarking (timing). + /// When disabled, kernels execute once with no timing overhead. + void set_benchmarking(bool enable) { benchmarking_ = enable; } + [[nodiscard]] bool benchmarking_enabled() const { return benchmarking_; } + + /// Alias kept for backward compatibility + const GroupedConvKernelInstance* select(const GroupedConvProblem& problem) const + { + return select_kernel(problem); + } + + private: + const GroupedConvKernelInstance* select_heuristic(const GroupedConvProblem& problem) const + { + if(!heuristic_) + return registry_->find(problem); + + auto ranked_names = heuristic_(problem); + auto all = registry_->all_kernels(); + for(const auto& name : ranked_names) + { + for(const auto* kernel : all) + { + if(kernel->name().find(name) != std::string::npos && kernel->matches(problem)) + { + return kernel; + } + } + } + return registry_->find(problem); + } + + GroupedConvRegistry* registry_; + SelectionStrategy strategy_; + HeuristicFunction heuristic_; + bool benchmarking_ = true; +}; + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp b/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp new file mode 100644 index 0000000000..c817d36673 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp @@ -0,0 +1,324 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/** + * @file grouped_conv_utils.hpp + * @brief CK Tile Grouped Convolution Dispatcher Utilities + */ + +#pragma once + +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/arch_filter.hpp" +#include "ck_tile/dispatcher/utils.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace ck_tile { +namespace dispatcher { + +using GroupedConvSig = grouped_conv_decl::GroupedConvSignature; +using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm; + +namespace grouped_conv_utils { + +inline GroupedConvKernelDecl create_grouped_conv2d_fwd(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("forward").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv3d_fwd(const std::string& dtype = "fp16", + int tile_n = 64, + int tile_k = 64, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("ndhwc").conv_type("forward").dims(3), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(16, 16, 32) + .pipeline("compv3") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv2d_bwd_data(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("bwd_data").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvKernelDecl create_grouped_conv2d_bwd_weight(const std::string& dtype = "fp16", + int tile_n = 128, + int tile_k = 128, + const std::string& arch = "gfx942") +{ + return GroupedConvKernelDecl( + GroupedConvSig().dtype(dtype).layout("nhwc").conv_type("bwd_weight").dims(2), + GroupedConvAlgo() + .tile(1, tile_n, tile_k) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv3") + .memory_op("atomic_add") + .vector_sizes(4, 8, 8), + arch); +} + +inline GroupedConvProblem create_grouped_conv2d_problem(int N, + int C, + int K, + int Hi, + int Wi, + int Y, + int X, + int stride = 1, + int padding = 0, + GroupedConvOp op = GroupedConvOp::Forward) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = K; + p.G = 1; + p.input_spatial = {1, Hi, Wi}; + p.filter_spatial = {1, Y, X}; + p.stride = {1, stride, stride}; + p.padding = {0, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = op; + p.compute_output_size(); + return p; +} + +inline GroupedConvProblem create_grouped_conv3d_problem(int N, + int C, + int K, + int Di, + int Hi, + int Wi, + int Z, + int Y, + int X, + int stride = 1, + int padding = 0, + GroupedConvOp op = GroupedConvOp::Forward) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = K; + p.G = 1; + p.input_spatial = {Di, Hi, Wi}; + p.filter_spatial = {Z, Y, X}; + p.stride = {stride, stride, stride}; + p.padding = {padding, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = op; + p.compute_output_size(); + return p; +} + +inline GroupedConvProblem create_depthwise_grouped_conv2d_problem( + int N, int C, int Hi, int Wi, int Y, int X, int stride = 1, int padding = 0) +{ + GroupedConvProblem p; + p.N = N; + p.C = C; + p.K = C; + p.G = C; + p.input_spatial = {1, Hi, Wi}; + p.filter_spatial = {1, Y, X}; + p.stride = {1, stride, stride}; + p.padding = {0, padding, padding}; + p.dilation = {1, 1, 1}; + p.op = GroupedConvOp::Forward; + p.compute_output_size(); + return p; +} + +inline void print_pattern_docs(std::ostream& os = std::cout) +{ + os << "Grouped Convolution Pattern Documentation\n"; + os << "==========================================\n"; + os << "Signature patterns: dtype, layout, conv_type (forward/bwd_data/bwd_weight), dims " + "(2/3)\n"; + os << "Algorithm patterns: tile(M,N,K), wave(M,N,K), warp(M,N,K), pipeline, vector_sizes\n"; + os << "Arch patterns: gfx942, gfx90a, gfx950, or '*' for all\n"; +} + +inline void print_grouped_conv_kernel_decl(const GroupedConvKernelDecl& decl, + std::ostream& os = std::cout) +{ + os << "GroupedConvKernelDecl: " << decl.name() << "\n"; + os << " Signature: dtype=" << decl.signature.dtype_in_ << ", layout=" << decl.signature.layout_ + << ", conv_type=" << decl.signature.conv_op_ << ", dims=" << decl.signature.num_dims_ + << "\n"; + os << " Algorithm: tile=" << decl.algorithm.tile_m_ << "x" << decl.algorithm.tile_n_ << "x" + << decl.algorithm.tile_k_ << ", wave=" << decl.algorithm.wave_m_ << "x" + << decl.algorithm.wave_n_ << "x" << decl.algorithm.wave_k_ + << ", warp=" << decl.algorithm.warp_m_ << "x" << decl.algorithm.warp_n_ << "x" + << decl.algorithm.warp_k_ << ", pipeline=" << decl.algorithm.pipeline_ << "\n"; + os << " Arch: " << decl.arch << "\n"; +} + +inline void print_grouped_conv_problem(const GroupedConvProblem& p, std::ostream& os = std::cout) +{ + os << p.to_string() << "\n"; + os << " FLOPs: " << std::scientific << p.get_flops() << "\n"; +} + +inline GroupedConvKernelSet build_grouped_conv2d_fwd_set(const std::string& dtype = "fp16", + const std::string& arch = "gfx942") +{ + GroupedConvKernelSet set; + auto decl1 = create_grouped_conv2d_fwd(dtype, 128, 128, arch); + set.add(decl1.signature, decl1.algorithm, decl1.arch); + auto decl2 = create_grouped_conv2d_fwd(dtype, 256, 256, arch); + set.add(decl2.signature, decl2.algorithm, decl2.arch); + return set; +} + +inline GroupedConvKernelSet build_grouped_conv2d_full_set(const std::string& dtype = "fp16", + const std::string& arch = "gfx942") +{ + GroupedConvKernelSet set; + set.merge(build_grouped_conv2d_fwd_set(dtype, arch)); + auto bwd_data = create_grouped_conv2d_bwd_data(dtype, 128, 128, arch); + set.add(bwd_data.signature, bwd_data.algorithm, bwd_data.arch); + auto bwd_weight = create_grouped_conv2d_bwd_weight(dtype, 128, 128, arch); + set.add(bwd_weight.signature, bwd_weight.algorithm, bwd_weight.arch); + return set; +} + +struct ValidationResult +{ + bool passed = false; + float max_abs_diff = 0.0f; + float max_rel_diff = 0.0f; + float rtol = 1e-3f; + float atol = 1e-3f; + + void print(std::ostream& os = std::cout) const + { + os << "ValidationResult: " << (passed ? "PASSED" : "FAILED") << "\n"; + os << " max_abs_diff: " << max_abs_diff << ", max_rel_diff: " << max_rel_diff << "\n"; + os << " rtol: " << rtol << ", atol: " << atol << "\n"; + } +}; + +template +inline ValidationResult validate_buffers( + const T* result, const T* reference, size_t count, float rtol = 1e-3f, float atol = 1e-3f) +{ + ValidationResult vr; + vr.rtol = rtol; + vr.atol = atol; + vr.passed = true; + + for(size_t i = 0; i < count; ++i) + { + float r = static_cast(result[i]); + float ref = static_cast(reference[i]); + float abs_diff = std::abs(r - ref); + float rel_diff = (std::abs(ref) > 1e-10f) ? (abs_diff / std::abs(ref)) : 0.0f; + + vr.max_abs_diff = std::max(vr.max_abs_diff, abs_diff); + vr.max_rel_diff = std::max(vr.max_rel_diff, rel_diff); + + float threshold = atol + rtol * std::abs(ref); + if(abs_diff > threshold) + { + vr.passed = false; + } + } + + return vr; +} + +struct BenchmarkResult +{ + std::string kernel_name; + float time_ms = 0.0f; + float tflops = 0.0f; + int warmup_runs = 0; + int benchmark_runs = 0; + + void print(std::ostream& os = std::cout) const + { + os << "BenchmarkResult: " << kernel_name << "\n"; + os << " time_ms: " << time_ms << ", tflops: " << tflops << "\n"; + os << " warmup_runs: " << warmup_runs << ", benchmark_runs: " << benchmark_runs << "\n"; + } +}; + +inline float calc_tflops(double flops, float time_ms) +{ + return static_cast(flops / (time_ms * 1e9)); +} + +inline double calculate_conv_tflops(const GroupedConvProblem& problem, double time_ms) +{ + return problem.get_flops() / (time_ms * 1e9); +} + +} // namespace grouped_conv_utils + +namespace examples { +inline int basic_grouped_conv_example_main(const std::string& example_name) +{ + std::cout << "=== " << example_name << " ===\n"; + + // Create a grouped convolution problem + auto problem = grouped_conv_utils::create_grouped_conv2d_problem( + 32, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + grouped_conv_utils::print_grouped_conv_problem(problem); + + // Create and print a kernel declaration + auto decl = grouped_conv_utils::create_grouped_conv2d_fwd("fp16", 128, 128, "gfx942"); + grouped_conv_utils::print_grouped_conv_kernel_decl(decl); + + // Build and print kernel set + auto kernel_set = grouped_conv_utils::build_grouped_conv2d_fwd_set("fp16", "gfx942"); + kernel_set.print(); + + return 0; +} +} // namespace examples + +} // namespace dispatcher +} // namespace ck_tile diff --git a/dispatcher/include/ck_tile/dispatcher/problem.hpp b/dispatcher/include/ck_tile/dispatcher/problem.hpp index 437511d1ba..5bffb56b49 100644 --- a/dispatcher/include/ck_tile/dispatcher/problem.hpp +++ b/dispatcher/include/ck_tile/dispatcher/problem.hpp @@ -98,7 +98,7 @@ struct Problem /** * Create Problem by inferring MNK from tensor shapes. * - * For GEMM: C[M,N] = A[M,K] × B[K,N] + * For GEMM: C[M,N] = A[M,K] x B[K,N] * * @param a_shape Shape of matrix A (M x K, or K x M if transposed) * @param b_shape Shape of matrix B (K x N, or N x K if transposed) @@ -113,7 +113,7 @@ struct Problem [[nodiscard]] static Problem from_shapes(TensorShape a_shape, TensorShape b_shape, TensorShape c_shape) { - // For C = A × B: + // For C = A x B: // A: [M, K] (or [K, M] if transposed) // B: [K, N] (or [N, K] if transposed) // C: [M, N] @@ -164,7 +164,7 @@ struct Problem * @throws std::invalid_argument if dimensions are inconsistent * * Example: - * // A[512,256] × B[256,1024] = C[512,1024] + * // A[512,256] x B[256,1024] = C[512,1024] * auto problem = Problem::from_dimensions(512, 256, 256, 1024, 512, 1024); */ [[nodiscard]] static Problem from_dimensions(std::int64_t a_rows, @@ -188,7 +188,7 @@ struct Problem * @throws std::invalid_argument if K dimensions don't match * * Example: - * // A[512,256] × B[256,1024] = C[512,1024] + * // A[512,256] x B[256,1024] = C[512,1024] * auto problem = Problem::from_ab(512, 256, 256, 1024); */ [[nodiscard]] static Problem diff --git a/dispatcher/include/ck_tile/dispatcher/registry.hpp b/dispatcher/include/ck_tile/dispatcher/registry.hpp index 93d1eb9f64..4f34e589ea 100644 --- a/dispatcher/include/ck_tile/dispatcher/registry.hpp +++ b/dispatcher/include/ck_tile/dispatcher/registry.hpp @@ -7,38 +7,20 @@ * Central registry for all available kernel instances with priority-based * ordering and efficient lookup. * - * Features: - * - Thread-safe registration and lookup - * - Priority-based ordering (High, Normal, Low) - * - Lookup by name or KernelKey - * - Filter by problem compatibility - * - Supports both singleton and multiple instance patterns - * - * Usage (Singleton - backward compatible): - * auto& registry = Registry::instance(); - * registry.register_kernel(kernel, Priority::High); - * auto kernel = registry.lookup("kernel_name"); - * - * Usage (Multiple registries): - * Registry fp16_registry; - * Registry bf16_registry; - * fp16_registry.register_kernel(fp16_kernel, Priority::High); - * bf16_registry.register_kernel(bf16_kernel, Priority::High); - * - * Dispatcher fp16_dispatcher(&fp16_registry); - * Dispatcher bf16_dispatcher(&bf16_registry); + * Derives from BaseRegistry for shared logic (thread safety, naming, priority, + * merge) while keeping GEMM-specific APIs (lookup by KernelKey, filter_by_arch, + * JSON export, auto-export). * * Status: Production ready, thread-safe */ #pragma once +#include "ck_tile/dispatcher/base_registry.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/kernel_key.hpp" #include -#include #include -#include #include #include @@ -47,20 +29,16 @@ namespace dispatcher { /// Registry: Central mapping from kernel configurations to executable instances /// Thread-safe kernel registration and lookup -/// Supports both singleton pattern and multiple independent instances -class Registry +/// Derives from BaseRegistry for shared functionality +class Registry : public BaseRegistry { + using Base = BaseRegistry; + public: - /// Priority levels for conflict resolution when multiple kernels have same key - enum class Priority - { - Low = 0, - Normal = 1, - High = 2 - }; + // Re-export Priority from the shared enum for backward compatibility + using Priority = ck_tile::dispatcher::Priority; /// Default constructor - creates an empty registry instance - /// Use this to create independent registries for different kernel sets Registry(); /// Destructor - triggers auto-export if enabled @@ -72,106 +50,51 @@ class Registry /// Move assignment Registry& operator=(Registry&& other) noexcept; - // Prevent copying (registries contain shared_ptrs that shouldn't be duplicated) + // Prevent copying Registry(const Registry&) = delete; Registry& operator=(const Registry&) = delete; /// Register a kernel instance with the registry - /// @param instance Kernel instance to register - /// @param priority Priority level for conflict resolution (default: Normal) - /// @return true if registered successfully, false if duplicate with higher priority exists bool register_kernel(KernelInstancePtr instance, Priority priority = Priority::Normal); /// Lookup a kernel by its string identifier - /// @param identifier Kernel identifier string - /// @return Kernel instance if found, nullptr otherwise [[nodiscard]] KernelInstancePtr lookup(const std::string& identifier) const; /// Lookup a kernel by its KernelKey - /// @param key Kernel configuration key - /// @return Kernel instance if found, nullptr otherwise [[nodiscard]] KernelInstancePtr lookup(const KernelKey& key) const; /// Get all registered kernels - /// @return Vector of all kernel instances [[nodiscard]] std::vector get_all() const; /// Get all kernels matching a predicate - /// @param predicate Function to filter kernels - /// @return Vector of matching kernel instances [[nodiscard]] std::vector filter(std::function predicate) const; - /// Get number of registered kernels - [[nodiscard]] std::size_t size() const; - - /// Check if registry is empty - [[nodiscard]] bool empty() const; - - /// Clear all registered kernels - void clear(); - - /// Get registry name (for logging/debugging) - [[nodiscard]] const std::string& get_name() const; - - /// Set registry name (for logging/debugging) - void set_name(const std::string& name); + // size(), empty(), clear(), get_name(), set_name(), merge_from() inherited from Base /// Export registry to JSON string - /// @param include_statistics Whether to include kernel statistics breakdown - /// @return JSON string with all kernel metadata [[nodiscard]] std::string export_json(bool include_statistics = true) const; /// Export registry to JSON file - /// @param filename Output filename - /// @param include_statistics Whether to include kernel statistics breakdown - /// @return true if export succeeded, false otherwise bool export_json_to_file(const std::string& filename, bool include_statistics = true) const; - /// Enable automatic JSON export on kernel registration - /// @param filename Output filename for auto-export - /// @param include_statistics Whether to include statistics in auto-export - /// @param export_on_every_registration If true, exports after every registration (default). - /// If false, only exports on destruction. void enable_auto_export(const std::string& filename, bool include_statistics = true, bool export_on_every_registration = true); - /// Disable automatic JSON export void disable_auto_export(); - /// Check if auto-export is enabled [[nodiscard]] bool is_auto_export_enabled() const; - /// Merge kernels from another registry into this one - /// @param other Registry to merge from - /// @param priority Priority for merged kernels (default: Normal) - /// @return Number of kernels successfully merged - std::size_t merge_from(const Registry& other, Priority priority = Priority::Normal); - /// Filter kernels in-place by architecture - /// @param gpu_arch Target GPU architecture string (e.g., "gfx942") - /// @return Number of kernels removed std::size_t filter_by_arch(const std::string& gpu_arch); - /// Get singleton instance of the global registry (backward compatible) - /// This is the default registry used when no specific registry is provided + /// Get singleton instance static Registry& instance(); private: - struct RegistryEntry - { - KernelInstancePtr instance; - Priority priority; - }; - - /// Perform auto-export if enabled void perform_auto_export(); - mutable std::mutex mutex_; - std::unordered_map kernels_; - std::string name_; - // Auto-export configuration bool auto_export_enabled_ = false; std::string auto_export_filename_; @@ -179,7 +102,7 @@ class Registry bool auto_export_on_every_registration_ = true; }; -/// Shared pointer type for registries (useful for managing lifetime) +/// Shared pointer type for registries using RegistryPtr = std::shared_ptr; /// Create a new registry instance (factory function) diff --git a/dispatcher/include/ck_tile/dispatcher_conv.hpp b/dispatcher/include/ck_tile/dispatcher_conv.hpp new file mode 100644 index 0000000000..46d14f90f3 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher_conv.hpp @@ -0,0 +1,18 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Grouped Convolution-only dispatcher header -- minimal include for conv operations. + +#pragma once + +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// Grouped Convolution +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" diff --git a/dispatcher/include/ck_tile/dispatcher_gemm.hpp b/dispatcher/include/ck_tile/dispatcher_gemm.hpp new file mode 100644 index 0000000000..79317c7399 --- /dev/null +++ b/dispatcher/include/ck_tile/dispatcher_gemm.hpp @@ -0,0 +1,22 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// GEMM-only dispatcher header -- minimal include for GEMM operations. + +#pragma once + +// Core (needed by all ops) +#include "ck_tile/dispatcher/base_registry.hpp" +#include "ck_tile/dispatcher/dispatcher_error.hpp" +#include "ck_tile/dispatcher/example_args.hpp" + +// GEMM +#include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/kernel_config.hpp" +#include "ck_tile/dispatcher/kernel_decl.hpp" +#include "ck_tile/dispatcher/kernel_instance.hpp" +#include "ck_tile/dispatcher/problem.hpp" +#include "ck_tile/dispatcher/registry.hpp" +#include "ck_tile/dispatcher/dispatcher.hpp" +#include "ck_tile/dispatcher/json_export.hpp" +#include "ck_tile/dispatcher/utils.hpp" diff --git a/dispatcher/python/CMakeLists.txt b/dispatcher/python/CMakeLists.txt index e57678952e..71634fa926 100644 --- a/dispatcher/python/CMakeLists.txt +++ b/dispatcher/python/CMakeLists.txt @@ -3,7 +3,7 @@ # This directory contains Python utilities for the dispatcher examples. # The main utility file is ctypes_utils.py which is used by GEMM Python examples. -# Conv Python examples use their own conv_utils.py in the examples directory. +# Grouped conv Python examples use grouped_conv_utils.py in this directory. # No build targets needed - these are pure Python utilities. message(STATUS "Python utilities directory configured (no build targets)") diff --git a/dispatcher/python/README.md b/dispatcher/python/README.md index 9286acbf72..edbc7acc9d 100644 --- a/dispatcher/python/README.md +++ b/dispatcher/python/README.md @@ -4,6 +4,19 @@ This directory contains Python utilities used by the dispatcher examples. ## Contents +### Shared Utilities (used by both GEMM and Grouped Conv) + +- `dispatcher_common.py` - Shared dispatcher infrastructure + - Path helpers (`get_dispatcher_root`, `get_build_dir`, etc.) + - `ValidationResultBase` - Structured validation feedback + - `validate_wave_config`, `validate_warp_tile_config`, `validate_trait_combo` + - `auto_correct_wave`, `auto_correct_trait` - Auto-correction helpers + - `Colors` - Cross-platform ANSI color support + - `print_phase`, `print_success`, `print_error`, `print_info` - Phased output + - `cleanup_generated_kernels` - Cleanup helper + +### GEMM Utilities + - `ctypes_utils.py` - Core ctypes utilities for GEMM Python examples - `KernelConfig` - Kernel configuration dataclass - `setup_gemm_dispatcher()` - Setup dispatcher with auto-correction @@ -11,11 +24,15 @@ This directory contains Python utilities used by the dispatcher examples. - `GemmRunner` - GPU execution helper - Auto-correction and validation utilities -- `conv_utils.py` - Core utilities for Conv Python examples - - `ConvSignature`, `ConvAlgorithm` - Convolution configuration - - `ConvProblem` - Problem definition - - `GpuConvRunner` - GPU execution helper - - `EnhancedConvCodegenRunner` - Kernel codegen utilities +### Grouped Convolution Utilities + +- `grouped_conv_utils.py` - Utilities for grouped convolution + - `GroupedConvValidationResult` - Validation result (extends `ValidationResultBase`) + - `validate_grouped_conv_config` - Validate a grouped conv config + - `auto_correct_grouped_conv_config` - Auto-correct invalid configs + - `get_grouped_conv_default_config` - Get default config for a variant + - `GroupedConvDataType` - Data type enum (FP16, BF16, FP32, FP8, BF8, INT8) + - `format_grouped_conv_summary` - Human-readable config summary ## Usage @@ -36,21 +53,26 @@ from ctypes_utils import ( ) ``` -### Conv Examples - -The Conv Python examples in `dispatcher/examples/conv/python/` import: +### Grouped Conv Usage ```python import sys from pathlib import Path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python")) -from conv_utils import ( - ConvSignature, - ConvAlgorithm, - ConvProblem, - GpuConvRunner, +from grouped_conv_utils import ( + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, + GroupedConvDataType, ) + +# Get a default config +config = get_grouped_conv_default_config(variant="forward", arch="gfx942") + +# Validate +result = validate_grouped_conv_config(config) +print(f"Valid: {result.is_valid}") ``` ## Requirements diff --git a/dispatcher/python/ctypes_utils.py b/dispatcher/python/ctypes_utils.py index 821fc2b08d..c11aaca835 100644 --- a/dispatcher/python/ctypes_utils.py +++ b/dispatcher/python/ctypes_utils.py @@ -37,6 +37,43 @@ import multiprocessing import time +# ============================================================================= +# GPU Architecture Auto-Detection +# ============================================================================= + +_detected_arch: Optional[str] = None + + +def detect_gpu_arch(fallback: str = "gfx942") -> str: + """ + Auto-detect the GPU architecture by querying rocminfo. + + Caches the result after the first call. Falls back to `fallback` if + detection fails (e.g. no GPU, rocminfo not installed). + """ + global _detected_arch + if _detected_arch is not None: + return _detected_arch + + try: + result = subprocess.run( + ["/opt/rocm/bin/rocminfo"], capture_output=True, text=True, timeout=10 + ) + for line in result.stdout.splitlines(): + stripped = line.strip() + if stripped.startswith("Name:") and "gfx" in stripped: + # Extract e.g. "gfx950" from "Name: gfx950" + name = stripped.split(":", 1)[1].strip() + if name.startswith("gfx") and name[3:].isdigit(): + _detected_arch = name + return _detected_arch + except Exception: + pass + + _detected_arch = fallback + return _detected_arch + + # ============================================================================= # Path Configuration # ============================================================================= @@ -159,9 +196,9 @@ class ValidationResult: def print_result(self, indent: str = " "): """Print validation result.""" if self.is_valid: - print(f"{indent}✓ Configuration valid") + print(f"{indent}OK Configuration valid") else: - print(f"{indent}⚠ Configuration has issues:") + print(f"{indent}WARNING Configuration has issues:") for err in self.errors: print(f"{indent} - {err}") @@ -300,7 +337,7 @@ def auto_correct_kernel_config( # Check each fix and describe what changed if "scheduler" in fixes and fixes["scheduler"] != config.scheduler: corrections.append( - f"Scheduler: {config.scheduler} → {fixes['scheduler']} " + f"Scheduler: {config.scheduler} -> {fixes['scheduler']} " f"('{config.scheduler}' not supported with pipeline={config.pipeline}, epilogue={config.epilogue})" ) @@ -309,7 +346,7 @@ def auto_correct_kernel_config( new_wave = f"[{fixes.get('wave_m', config.wave_m)}, {fixes.get('wave_n', config.wave_n)}, {fixes.get('wave_k', config.wave_k)}]" if old_wave != new_wave: corrections.append( - f"Wave config: {old_wave} → {new_wave} " + f"Wave config: {old_wave} -> {new_wave} " f"(original not supported on {config.gfx_arch})" ) @@ -318,7 +355,7 @@ def auto_correct_kernel_config( new_warp = f"[{fixes.get('warp_m', config.warp_m)}, {fixes.get('warp_n', config.warp_n)}, {fixes.get('warp_k', config.warp_k)}]" if old_warp != new_warp: corrections.append( - f"Warp tile: {old_warp} → {new_warp} " + f"Warp tile: {old_warp} -> {new_warp} " f"(original not supported for {config.dtype_a} on {config.gfx_arch})" ) @@ -386,13 +423,13 @@ def print_auto_correction( indent: Indentation for output """ if not corrections: - print(f"{indent}✓ Configuration valid - no corrections needed") + print(f"{indent}OK Configuration valid - no corrections needed") return - print(f"\n{indent}⚠ AUTO-CORRECTION APPLIED:") + print(f"\n{indent}WARNING AUTO-CORRECTION APPLIED:") print(f"{indent}" + "-" * 50) for correction in corrections: - print(f"{indent} • {correction}") + print(f"{indent} - {correction}") print(f"{indent}" + "-" * 50) print() @@ -976,6 +1013,226 @@ def _run_codegen_subprocess(args: Dict[str, Any]) -> CodegenResult: ) +def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]: + """Module-level function to run hipcc compilation in parallel.""" + import subprocess + from pathlib import Path + + compile_cmd = args["compile_cmd"] + link_cmd = args["link_cmd"] + lib_path = Path(args["lib_path"]) + + try: + res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300) + if res_c.returncode != 0: + return False, None, f"Compile failed: {res_c.stderr[:200]}" + + res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300) + if res_l.returncode != 0: + return False, None, f"Link failed: {res_l.stderr[:200]}" + + return True, lib_path, "" + except subprocess.TimeoutExpired: + return False, None, "Timeout" + except Exception as e: + return False, None, str(e) + + +def _generate_single_kernel_subprocess(args: dict) -> Tuple[bool, Optional[str], str]: + """Module-level function: generate ONE kernel .hpp via --config JSON file. + + Used by setup_multiple_gemm_dispatchers for per-config parallel codegen. + Returns (success, header_path_or_None, error_msg). + """ + import subprocess + import json + import tempfile + import os + from pathlib import Path + + try: + out_dir = Path(args["output_dir"]) + out_dir.mkdir(parents=True, exist_ok=True) + + # Write the single-config JSON to a temp file + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(args["tile_config_json"], f) + config_file = f.name + + cmd = [ + args["python"], + str(args["codegen_script"]), + "--output-dir", + str(out_dir), + "--datatype", + args["dtype"], + "--layout", + args["layout"], + "--gpu-target", + args["gpu_target"], + "--config", + config_file, + "--variants", + "standard", + ] + + res = subprocess.run(cmd, capture_output=True, text=True, timeout=300) + os.unlink(config_file) + + if res.returncode != 0: + return False, None, f"Codegen failed: {res.stderr[:200]}" + + # Find the generated .hpp using the expected name pattern + pattern = args["hpp_glob_pattern"] + matches = sorted(out_dir.glob(pattern)) + if matches: + return True, str(matches[0]), "" + else: + return False, None, f"No .hpp matching {pattern} after codegen" + + except Exception as e: + return False, None, str(e) + + +def _parse_triplet(text: str) -> Optional[Tuple[int, int, int]]: + parts = text.split("x") + if len(parts) != 3: + return None + try: + return (int(parts[0]), int(parts[1]), int(parts[2])) + except ValueError: + return None + + +def _parse_gemm_header_metadata(header: Path) -> Optional[Dict[str, Any]]: + """ + Parse GEMM header name into configuration metadata. + + Expected stem format: + gemm_{dtype}_{layout}_{pipeline}_{epilogue}_{scheduler} + _{pad_m}_{pad_n}_{pad_k}_{persistent} + _{tile_m}x{tile_n}x{tile_k}_{wave_m}x{wave_n}x{wave_k}_{warp_m}x{warp_n}x{warp_k} + """ + parts = header.stem.split("_") + if len(parts) < 13 or parts[0] != "gemm": + return None + + tile = _parse_triplet(parts[10]) + wave = _parse_triplet(parts[11]) + warp = _parse_triplet(parts[12]) + if tile is None or wave is None or warp is None: + return None + + def _as_bool(v: str) -> bool: + return v.lower() == "true" + + return { + "dtype": parts[1], + "layout": parts[2], + "pipeline": parts[3], + "epilogue": parts[4], + "scheduler": parts[5], + "pad_m": _as_bool(parts[6]), + "pad_n": _as_bool(parts[7]), + "pad_k": _as_bool(parts[8]), + "persistent": _as_bool(parts[9]), + "tile": tile, + "wave": wave, + "warp": warp, + } + + +def _generate_arch_valid_gemm_headers( + python_exe: str, + codegen_script: Path, + output_dir: Path, + dtype: str, + layout: str, + gpu_target: str, + variant: str = "standard", +) -> Tuple[bool, List[Path], str]: + """Generate (or reuse) an arch-filtered kernel catalog for fallback selection.""" + output_dir.mkdir(parents=True, exist_ok=True) + pattern = f"gemm_{dtype}_{layout}_*.hpp" + existing = sorted(output_dir.glob(pattern)) + if existing: + return True, existing, "" + + cmd = [ + python_exe, + str(codegen_script), + "--output-dir", + str(output_dir), + "--datatype", + dtype, + "--layout", + layout, + "--gpu-target", + gpu_target, + "--variants", + variant, + ] + res = subprocess.run(cmd, capture_output=True, text=True, timeout=600) + if res.returncode != 0: + err = (res.stderr or res.stdout or "").strip()[:500] + return False, [], f"Catalog codegen failed: {err}" + + generated = sorted(output_dir.glob(pattern)) + if not generated: + return False, [], "Catalog codegen produced no GEMM headers" + return True, generated, "" + + +def _select_best_arch_valid_gemm_header( + config: "KernelConfig", + headers: List[Path], +) -> Tuple[Optional[Path], Optional[Dict[str, Any]]]: + """Choose nearest arch-valid header for a requested GEMM config.""" + best: Optional[Path] = None + best_meta: Optional[Dict[str, Any]] = None + best_score: Optional[Tuple[int, int, int, int, int, int]] = None + + for h in headers: + meta = _parse_gemm_header_metadata(h) + if meta is None: + continue + if meta["dtype"] != config.dtype_a or meta["layout"] != config.layout: + continue + + tile = meta["tile"] + wave = meta["wave"] + warp = meta["warp"] + tile_delta = ( + abs(tile[0] - config.tile_m) + + abs(tile[1] - config.tile_n) + + abs(tile[2] - config.tile_k) + ) + wave_delta = ( + abs(wave[0] - config.wave_m) + + abs(wave[1] - config.wave_n) + + abs(wave[2] - config.wave_k) + ) + warp_delta = ( + abs(warp[0] - config.warp_m) + + abs(warp[1] - config.warp_n) + + abs(warp[2] - config.warp_k) + ) + score = ( + 0 if meta["pipeline"] == config.pipeline else 1, + 0 if meta["scheduler"] == config.scheduler else 1, + 0 if meta["epilogue"] == config.epilogue else 1, + tile_delta, + wave_delta, + warp_delta, + ) + if best_score is None or score < best_score: + best_score = score + best = h + best_meta = meta + + return best, best_meta + + # ============================================================================= # Preshuffle Utilities # ============================================================================= @@ -1319,7 +1576,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1337,7 +1594,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {variant}: FAILED - {e}") + print(f" FAIL {variant}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1399,7 +1656,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {tile_str}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1417,7 +1674,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {tile_str}: FAILED - {e}") + print(f" FAIL {tile_str}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1481,7 +1738,7 @@ class CodegenRunner: result = future.result() results.append(result) if verbose: - status = "✓" if result.success else "✗" + status = "OK" if result.success else "FAIL" print( f" {status} {variant}: {result.kernel_count} kernels in {result.elapsed_seconds:.2f}s" ) @@ -1499,7 +1756,7 @@ class CodegenRunner: ) ) if verbose: - print(f" ✗ {variant}: FAILED - {e}") + print(f" FAIL {variant}: FAILED - {e}") total_time = time.time() - start_total if verbose: @@ -1767,7 +2024,7 @@ class CodegenRunner: link_cmd, capture_output=True, text=True, timeout=300 ) if result.returncode == 0: - print(f" ✓ Library rebuilt: {lib_path.name}") + print(f" OK Library rebuilt: {lib_path.name}") # Clean up object file obj_file.unlink(missing_ok=True) return lib_path @@ -1781,6 +2038,105 @@ class CodegenRunner: print(f" Build error: {e}") return None + def build_libraries_parallel( + self, configs_and_headers: List[Tuple[KernelConfig, Path]], verbose: bool = True + ) -> List[Optional[Path]]: + """ + Build multiple libraries in parallel using ProcessPoolExecutor. + Returns a list of library paths (or None if a build failed) in the same order. + """ + import time + from concurrent.futures import ProcessPoolExecutor, as_completed + + start_time = time.time() + build_dir = get_build_dir() + root = get_dispatcher_root() + ck_root = root.parent + ctypes_source = root / "bindings/ctypes/gemm_ctypes_lib.cpp" + static_lib = build_dir / "libck_tile_dispatcher.a" + + if not ctypes_source.exists() or not static_lib.exists(): + if verbose: + print(" Required source or static library missing for parallel build.") + return [None] * len(configs_and_headers) + + args_list = [] + for config, kernel_header in configs_and_headers: + lib_name = f"libdispatcher_gemm_{config.dtype_a}_{config.layout}_{config.tile_str}_{config.pipeline}.so" + lib_path = build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{root / 'build/generated_kernels'}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{kernel_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={config.gfx_arch}", + f'-DGFX_ARCH="{config.gfx_arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={config.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + args_list.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + "config_name": f"{config.dtype_a}_{config.layout}_{config.tile_str}", + } + ) + + if verbose: + print( + f"Building {len(args_list)} libraries in parallel (workers={self.max_workers})..." + ) + + results_map = {} + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, args): i + for i, args in enumerate(args_list) + } + for future in as_completed(futures): + idx = futures[future] + success, lib_path, err = future.result() + results_map[idx] = Path(lib_path) if success else None + if verbose: + status = "OK" if success else f"FAIL ({err})" + print( + f" {status} {Path(lib_path).name if success else args_list[idx]['config_name']}" + ) + + if verbose: + elapsed = time.time() - start_time + print(f"Parallel build finished in {elapsed:.2f}s") + + return [results_map[i] for i in range(len(configs_and_headers))] + def generate_preselected( self, preset: str = "fp16_rcr_essential", output_dir: Optional[Path] = None ) -> CodegenResult: @@ -1933,6 +2289,28 @@ class Registry: """Bind to a loaded dispatcher library.""" self._lib = lib + def build( + self, + verbose: bool = False, + max_workers: Optional[int] = None, + ) -> List["GemmSetupResult"]: + """Parallel JIT compile all kernels in this registry. + + Args: + verbose: Print progress during build. + max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8). + + Returns a GemmSetupResult per registered kernel (same order as get_kernels()). + """ + if not self._kernels: + return [] + return setup_multiple_gemm_dispatchers( + self._kernels, + registry_name=self._name, + verbose=verbose, + max_workers=max_workers, + ) + def __repr__(self) -> str: return f"Registry(name='{self._name}', kernels={self.kernel_count})" @@ -2109,7 +2487,7 @@ def setup_gemm_dispatcher( log(" Validating config...") validation = validate_kernel_config(config) if not validation.is_valid: - log(" ⚠ Auto-correcting configuration...") + log(" WARNING Auto-correcting configuration...") config, was_modified, corrections = auto_correct_kernel_config( config, verbose=verbose ) @@ -2128,13 +2506,13 @@ def setup_gemm_dispatcher( codegen_result = codegen.generate_from_config(config) if not codegen_result.success: - log(" ⚠ Kernel generation: using existing") + log(" WARNING Kernel generation: using existing") # Step 3: Find matching kernel header kernel_header = find_matching_kernel_header(config) result.kernel_header = kernel_header if not kernel_header: - log(" ⚠ No matching kernel header found") + log(" WARNING No matching kernel header found") # Step 4: Load library log(" Loading library...") @@ -2188,11 +2566,11 @@ def setup_gemm_dispatcher( result.error = "Failed to load rebuilt library" return result result.lib = lib - log(f" ✓ Rebuilt library: {lib.get_kernel_name()}") + log(f" OK Rebuilt library: {lib.get_kernel_name()}") else: - log(" ⚠ Rebuild failed, using existing library") + log(" WARNING Rebuild failed, using existing library") else: - log(" ⚠ No kernel header found for config, using existing library") + log(" WARNING No kernel header found for config, using existing library") # Step 5: Create registry and dispatcher log(" Creating registry and dispatcher...") @@ -2203,12 +2581,305 @@ def setup_gemm_dispatcher( dispatcher = Dispatcher(registry=registry, lib=lib) result.dispatcher = dispatcher - log(f" ✓ Ready: {lib.get_kernel_name()}") + log(f" OK Ready: {lib.get_kernel_name()}") result.success = True return result +def setup_multiple_gemm_dispatchers( + configs: List[KernelConfig], + registry_name: str = "gemm_registry", + verbose: bool = True, + max_workers: Optional[int] = None, +) -> List[GemmSetupResult]: + """ + Setup multiple GEMM dispatchers in parallel. + + Pipeline: + 1. Validate + auto-correct each config + 2. Parallel codegen: generate .hpp for each config via --config JSON + 3. Parallel hipcc: compile each .hpp -> .so + 4. Load + wire up each .so into a GemmSetupResult + + Each config gets its own .so, so different tile sizes can coexist. + + Args: + max_workers: Max parallel processes for codegen/compile (default: cpu_count capped at 8). + """ + import sys + + results = [GemmSetupResult(success=False, config=c) for c in configs] + max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + + # -- Step 1: Validate & correct --------------------------------------- + valid_configs = [] + for i, c in enumerate(configs): + val = validate_kernel_config(c) + if not val.is_valid: + c, modified, corrections = auto_correct_kernel_config(c, verbose=False) + results[i].config = c + results[i].corrections = corrections + valid_configs.append(c) + + # -- Step 2: Parallel codegen (one --config JSON per config) ---------- + codegen_script = get_codegen_path() + output_dir = get_generated_kernels_dir() + + codegen_args = [] + for c in valid_configs: + tile_str = c.tile_str + wave_str = f"{c.wave_m}x{c.wave_n}x{c.wave_k}" + warp_str = f"{c.warp_m}x{c.warp_n}x{c.warp_k}" + + tile_config_json = { + "tile_config": { + "tile_m": [c.tile_m], + "tile_n": [c.tile_n], + "tile_k": [c.tile_k], + "warp_m": [c.wave_m], + "warp_n": [c.wave_n], + "warp_k": [c.wave_k], + "warp_tile_m": [c.warp_m], + "warp_tile_n": [c.warp_n], + "warp_tile_k": [c.warp_k], + }, + "trait_config": { + "pipeline": [c.pipeline], + "epilogue": [c.epilogue], + "scheduler": [c.scheduler], + "pad_m": [c.pad_m], + "pad_n": [c.pad_n], + "pad_k": [c.pad_k], + "persistent": [False], + }, + } + + hpp_pattern = ( + f"gemm_{c.dtype_a}_{c.layout}_{c.pipeline}_{c.epilogue}_{c.scheduler}" + f"_*_{tile_str}_{wave_str}_{warp_str}.hpp" + ) + + codegen_args.append( + { + "python": sys.executable, + "codegen_script": str(codegen_script), + "output_dir": str(output_dir), + "dtype": c.dtype_a, + "layout": c.layout, + "gpu_target": c.gfx_arch, + "tile_config_json": tile_config_json, + "hpp_glob_pattern": hpp_pattern, + } + ) + + if verbose: + print( + f"Generating {len(codegen_args)} kernel headers in parallel (workers={max_workers})..." + ) + + headers: List[Optional[Path]] = [None] * len(valid_configs) + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_generate_single_kernel_subprocess, a): i + for i, a in enumerate(codegen_args) + } + for future in as_completed(futures): + idx = futures[future] + ok, hdr_str, err = future.result() + if ok and hdr_str: + headers[idx] = Path(hdr_str) + results[idx].kernel_header = Path(hdr_str) + if verbose: + print( + f" OK [{idx}] {valid_configs[idx].tile_str}: {Path(hdr_str).name}" + ) + else: + results[idx].error = f"Codegen: {err}" + if verbose: + print(f" FAIL [{idx}] {valid_configs[idx].tile_str}: {err}") + + # For configs rejected by arch filter, map to nearest arch-valid header. + fallback_needed = [i for i, h in enumerate(headers) if h is None] + if fallback_needed: + if verbose: + print( + f"Resolving {len(fallback_needed)} configs via arch-valid GEMM catalog..." + ) + + catalog_cache: Dict[Tuple[str, str, str, str], List[Path]] = {} + for i in fallback_needed: + c = valid_configs[i] + key = (c.gfx_arch, c.dtype_a, c.layout, c.variant) + if key not in catalog_cache: + catalog_dir = ( + output_dir + / "_arch_valid_catalog" + / (f"{c.gfx_arch}_{c.dtype_a}_{c.layout}_{c.variant}") + ) + ok, catalog_headers, err = _generate_arch_valid_gemm_headers( + python_exe=sys.executable, + codegen_script=codegen_script, + output_dir=catalog_dir, + dtype=c.dtype_a, + layout=c.layout, + gpu_target=c.gfx_arch, + variant=c.variant, + ) + if not ok: + catalog_headers = [] + if verbose: + print(f" FAIL [{i}] catalog generation: {err}") + catalog_cache[key] = catalog_headers + + chosen, meta = _select_best_arch_valid_gemm_header(c, catalog_cache[key]) + if chosen is None or meta is None: + continue + + headers[i] = chosen + results[i].kernel_header = chosen + results[i].error = "" + + # Keep Python-side config aligned with the selected kernel header. + valid_configs[i].pipeline = str(meta["pipeline"]) + valid_configs[i].epilogue = str(meta["epilogue"]) + valid_configs[i].scheduler = str(meta["scheduler"]) + valid_configs[i].pad_m = bool(meta["pad_m"]) + valid_configs[i].pad_n = bool(meta["pad_n"]) + valid_configs[i].pad_k = bool(meta["pad_k"]) + valid_configs[i].tile_m = int(meta["tile"][0]) + valid_configs[i].tile_n = int(meta["tile"][1]) + valid_configs[i].tile_k = int(meta["tile"][2]) + valid_configs[i].wave_m = int(meta["wave"][0]) + valid_configs[i].wave_n = int(meta["wave"][1]) + valid_configs[i].wave_k = int(meta["wave"][2]) + valid_configs[i].warp_m = int(meta["warp"][0]) + valid_configs[i].warp_n = int(meta["warp"][1]) + valid_configs[i].warp_k = int(meta["warp"][2]) + results[i].config = valid_configs[i] + + if verbose: + print(f" INFO [{i}] mapped to arch-valid header: {chosen.name}") + + # -- Step 3: Parallel hipcc compilation ------------------------------- + root = get_dispatcher_root() + ck_root = root.parent + build_dir = get_build_dir() + ctypes_source = root / "bindings" / "ctypes" / "gemm_ctypes_lib.cpp" + static_lib = build_dir / "libck_tile_dispatcher.a" + + if not ctypes_source.exists() or not static_lib.exists(): + for i in range(len(valid_configs)): + if results[i].error == "": + results[ + i + ].error = "Missing ctypes source or static library for compilation" + return results + + compile_jobs = [] + compile_index_map = {} + for i, c in enumerate(valid_configs): + hdr = headers[i] + if hdr is None: + continue + + lib_name = ( + f"libdispatcher_gemm_{c.dtype_a}_{c.layout}_{c.tile_str}_{c.pipeline}.so" + ) + lib_path = build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{root / 'include'}", + f"-I{ck_root / 'include'}", + f"-I{ck_root}", + f"-I{str(output_dir)}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{hdr}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={c.gfx_arch}", + f'-DGFX_ARCH="{c.gfx_arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={c.gfx_arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + compile_index_map[len(compile_jobs)] = i + compile_jobs.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + } + ) + + if verbose and compile_jobs: + print( + f"Compiling {len(compile_jobs)} libraries in parallel (workers={max_workers})..." + ) + + lib_paths: Dict[int, Optional[Path]] = {} + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, job): j + for j, job in enumerate(compile_jobs) + } + for future in as_completed(futures): + j = futures[future] + i = compile_index_map[j] + ok, lp, err = future.result() + if ok and lp: + lib_paths[i] = Path(lp) + if verbose: + print(f" OK [{i}] {valid_configs[i].tile_str}: {Path(lp).name}") + else: + results[i].error = f"Compile: {err}" + if verbose: + print(f" FAIL [{i}] {valid_configs[i].tile_str}: {err}") + + # -- Step 4: Load libraries and create dispatchers -------------------- + for i, c in enumerate(valid_configs): + lp = lib_paths.get(i) + if lp is None: + continue + + lib = DispatcherLib.load(lp) + if lib is not None and lib.initialize(): + results[i].lib = lib + reg = Registry(name=f"{registry_name}_{i}", lib=lib) + reg.register_kernel(c) + results[i].registry = reg + results[i].dispatcher = Dispatcher(registry=reg, lib=lib) + results[i].success = True + else: + results[i].error = "Failed to load compiled library" + + if verbose: + ok_count = sum(1 for r in results if r.success) + print(f"Setup complete: {ok_count}/{len(results)} dispatchers ready") + + return results + + def cleanup_gemm(): """ Cleanup function to call after running GEMM examples. diff --git a/dispatcher/python/dispatcher_common.py b/dispatcher/python/dispatcher_common.py new file mode 100644 index 0000000000..a19ecbdb49 --- /dev/null +++ b/dispatcher/python/dispatcher_common.py @@ -0,0 +1,372 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Shared Python dispatcher utilities for GEMM and grouped convolution. + +Extracted from ctypes_utils.py (GEMM) + compile_grouped_conv_examples.py (grouped conv). +Both ctypes_utils.py and grouped_conv_utils.py import from here to +eliminate duplication. + +Best-of-both: + - Validation and auto-correction return typed objects (GEMM pattern) + - Colors class with cross-platform ANSI handling (conv pattern) + - Phased output helpers (conv pattern) + - logging module instead of bare print() (shared improvement) +""" + +import logging +import shutil +import sys +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +log = logging.getLogger(__name__) + + +# ============================================================================ +# Path Configuration +# ============================================================================ + + +def get_dispatcher_root() -> Path: + """Get the dispatcher root directory (parent of python/).""" + return Path(__file__).parent.parent + + +def get_ck_root() -> Path: + """Get the CK root directory (parent of dispatcher/).""" + return get_dispatcher_root().parent + + +def get_build_dir() -> Path: + """Get the build directory.""" + return get_dispatcher_root() / "build" + + +def get_generated_kernels_dir() -> Path: + """Get the generated kernels directory.""" + return get_build_dir() / "generated_kernels" + + +def get_codegen_dir() -> Path: + """Get the codegen scripts directory.""" + return get_dispatcher_root() / "codegen" + + +# ============================================================================ +# Architecture Filter Data +# ============================================================================ + +_arch_data_cache: Optional[Dict[str, Any]] = None + + +def detect_gpu_arch(fallback: str = "gfx942") -> str: + """Detect the GPU architecture from rocminfo. Falls back to the given default.""" + import subprocess + + try: + out = subprocess.check_output( + ["rocminfo"], text=True, stderr=subprocess.DEVNULL + ) + for line in out.splitlines(): + if "Name:" in line and "gfx" in line: + return line.split()[-1].strip() + except Exception: + pass + return fallback + + +def get_arch_filter_data() -> Dict[str, Any]: + """Load arch filter data from arch_specs_generated if available. + + Returns dict with keys: trait_unsupported, warp_combos, + warp_tile_combos, supported_archs. + """ + global _arch_data_cache + if _arch_data_cache is not None: + return _arch_data_cache + + codegen_dir = get_dispatcher_root() / "codegen" + sys.path.insert(0, str(codegen_dir)) + + try: + from arch_specs_generated import ( + TRAIT_UNSUPPORTED_COMBINATIONS, + WARP_SUPPORTED_COMBINATIONS, + WARP_TILE_SUPPORTED_COMBINATIONS, + get_supported_archs, + ) + + _arch_data_cache = { + "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS, + "warp_combos": WARP_SUPPORTED_COMBINATIONS, + "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS, + "supported_archs": get_supported_archs(), + } + except ImportError: + _arch_data_cache = { + "trait_unsupported": { + ("compv3", "cshuffle", "interwave"), + ("compv3", "default", "interwave"), + ("compv4", "cshuffle", "interwave"), + ("compv4", "default", "interwave"), + }, + "warp_combos": { + "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]], + }, + "warp_tile_combos": { + "gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + "gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]}, + }, + "supported_archs": ["gfx90a", "gfx942", "gfx950"], + } + + return _arch_data_cache + + +# ============================================================================ +# Validation Result +# ============================================================================ + + +@dataclass +class ValidationResultBase: + """Result of kernel config validation (shared base for GEMM and conv).""" + + is_valid: bool + errors: List[str] = field(default_factory=list) + warnings: List[str] = field(default_factory=list) + suggested_fixes: Dict[str, Any] = field(default_factory=dict) + + def print_result(self, indent: str = " "): + if self.is_valid: + print(f"{indent}OK Configuration valid") + else: + print(f"{indent}WARNING Configuration has issues:") + for err in self.errors: + print(f"{indent} - {err}") + if self.warnings: + for warn in self.warnings: + print(f"{indent} Warning: {warn}") + if self.suggested_fixes: + print(f"{indent} Suggested fixes:") + for key, val in self.suggested_fixes.items(): + print(f"{indent} {key}: {val}") + + +# ============================================================================ +# Validation Helpers +# ============================================================================ + + +def validate_wave_config(wave_cfg: List[int], arch: str) -> Tuple[bool, str]: + """Validate a [wave_m, wave_n, wave_k] config for *arch*. + + Returns (is_valid, error_message). Empty string on success. + """ + data = get_arch_filter_data() + valid_waves = data["warp_combos"].get(arch, [[2, 2, 1]]) + if wave_cfg in valid_waves: + return True, "" + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in valid_waves) + return ( + False, + f"Unsupported wave configuration {wave_cfg} for {arch}. " + f"Valid wave configs: {valid_str}", + ) + + +def validate_warp_tile_config( + warp_cfg: List[int], arch: str, dtype: str +) -> Tuple[bool, str]: + """Validate a [warp_m, warp_n, warp_k] config for *arch*/*dtype*. + + Returns (is_valid, error_message). Empty string on success. + """ + data = get_arch_filter_data() + acc = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc}" + valid_tiles = ( + data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + if warp_cfg in valid_tiles: + return True, "" + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in valid_tiles[:5]) + return ( + False, + f"Unsupported warp tile {warp_cfg} for {arch}/{dtype}. " + f"Valid warp tiles: {valid_str}", + ) + + +def validate_trait_combo( + pipeline: str, epilogue: str, scheduler: str +) -> Tuple[bool, str]: + """Validate a (pipeline, epilogue, scheduler) combination. + + Returns (is_valid, error_message). Empty string on success. + """ + data = get_arch_filter_data() + combo = (pipeline, epilogue, scheduler) + if combo in data["trait_unsupported"]: + return ( + False, + f"Unsupported trait combination: pipeline={pipeline}, " + f"epilogue={epilogue}, scheduler={scheduler}", + ) + return True, "" + + +# ============================================================================ +# Auto-Correction Helpers +# ============================================================================ + + +def auto_correct_wave(wave_cfg: List[int], arch: str) -> List[int]: + """Return the first valid wave config for *arch*. + + If *wave_cfg* is already valid, returns it unchanged. + """ + data = get_arch_filter_data() + valid_waves = data["warp_combos"].get(arch, [[2, 2, 1]]) + if wave_cfg in valid_waves: + return wave_cfg + return valid_waves[0] if valid_waves else [2, 2, 1] + + +def auto_correct_trait(pipeline: str, scheduler: str) -> Tuple[str, str]: + """Return a corrected (pipeline, scheduler) pair. + + If the compute pipeline doesn't support interwave, switch to intrawave. + """ + data = get_arch_filter_data() + for epilogue in ("cshuffle", "default"): + if (pipeline, epilogue, scheduler) in data["trait_unsupported"]: + return pipeline, "intrawave" + return pipeline, scheduler + + +# ============================================================================ +# Colors (adopted from compile_grouped_conv_examples.py -- cross-platform) +# ============================================================================ + + +class Colors: + """Cross-platform ANSI color support. + + Respects sys.platform (no ANSI on Windows) and isatty() check so + piped/redirected output stays clean. + """ + + _GREEN = "\033[0;32m" + _YELLOW = "\033[1;33m" + _RED = "\033[0;31m" + _CYAN = "\033[0;36m" + _BOLD = "\033[1m" + _NC = "\033[0m" + + @classmethod + def _use_color(cls) -> bool: + return ( + sys.platform != "win32" + and hasattr(sys.stdout, "isatty") + and sys.stdout.isatty() + ) + + @classmethod + def green(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._GREEN}{text}{cls._NC}" + return text + + @classmethod + def red(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._RED}{text}{cls._NC}" + return text + + @classmethod + def yellow(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._YELLOW}{text}{cls._NC}" + return text + + @classmethod + def cyan(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._CYAN}{text}{cls._NC}" + return text + + @classmethod + def bold(cls, text: str) -> str: + if cls._use_color(): + return f"{cls._BOLD}{text}{cls._NC}" + return text + + +# ============================================================================ +# Phased Output Helpers +# ============================================================================ + + +def print_phase(number: int, description: str) -> None: + """Print a phase header (e.g. 'Phase 1: Codegen').""" + print(f"\n{'=' * 60}") + print(f" Phase {number}: {description}") + print(f"{'=' * 60}") + + +def print_success(message: str) -> None: + """Print a success message.""" + print(f" OK {Colors.green(message)}") + + +def print_error(message: str) -> None: + """Print an error message.""" + print(f" FAIL {Colors.red(message)}") + + +def print_info(message: str) -> None: + """Print an info message.""" + print(f" {Colors.cyan(message)}") + + +# ============================================================================ +# Cleanup Helpers +# ============================================================================ + + +def cleanup_generated_kernels(gen_dir: Optional[Path] = None) -> None: + """Remove generated kernel directory if it exists.""" + if gen_dir is None: + gen_dir = get_generated_kernels_dir() + if gen_dir.exists(): + shutil.rmtree(gen_dir, ignore_errors=True) + log.info("Cleaned up generated kernels at %s", gen_dir) + + +# ============================================================================ +# Tool Helpers +# ============================================================================ + + +def find_hipcc() -> Optional[str]: + """Find the hipcc compiler.""" + import os + + candidates = [ + os.environ.get("HIPCC"), + "/opt/rocm/bin/hipcc", + shutil.which("hipcc"), + ] + for path in candidates: + if path and os.path.isfile(path): + return path + return None diff --git a/dispatcher/python/grouped_conv_utils.py b/dispatcher/python/grouped_conv_utils.py new file mode 100644 index 0000000000..cd6ef5647c --- /dev/null +++ b/dispatcher/python/grouped_conv_utils.py @@ -0,0 +1,1806 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Grouped Convolution Dispatcher Utilities + +Typed Python API for grouped convolution kernels, matching the patterns from +the old conv_utils.py and the GEMM ctypes_utils.py. + +Classes: + GroupedConvKernelConfig - Kernel configuration (tile, wave, pipeline, arch) + GroupedConvProblem - Runtime problem specification (N,C,K,H,W,etc.) + GroupedConvProblemC - ctypes struct matching C++ ConvProblemC + GroupedConvDispatcherLib - Wrapper for libdispatcher_conv_lib.so + GpuGroupedConvRunner - High-level GPU execution runner + GroupedConvResult - Result of GPU execution (output, time, tflops) + GroupedConvRegistry - Collection of kernel configs with JSON export + +Usage: + from grouped_conv_utils import ( + GroupedConvKernelConfig, + GroupedConvProblem, + GpuGroupedConvRunner, + ) + + config = GroupedConvKernelConfig(variant="forward", ndim_spatial=2) + problem = GroupedConvProblem(N=1, C=64, K=128, Hi=28, Wi=28, Y=3, X=3, + stride_h=1, pad_h=1, direction="forward") + runner = GpuGroupedConvRunner() + if runner.is_available(): + result = runner.run(input_np, weight_np, problem) + print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}") +""" + +import ctypes +import json +import copy +import subprocess +from dataclasses import dataclass +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np + +from dispatcher_common import ( + ValidationResultBase, + auto_correct_trait, + auto_correct_wave, + get_arch_filter_data, + validate_trait_combo, + validate_wave_config, + validate_warp_tile_config, +) + + +# ============================================================================= +# Constants +# ============================================================================= + +VALID_VARIANTS = ("forward", "bwd_data", "bwd_weight") +VALID_NDIM_SPATIAL = (1, 2, 3) +BACKWARD_VARIANTS = ("bwd_data", "bwd_weight") +BACKWARD_PIPELINES = ("compv3", "mem") + +VARIANT_ALIASES = { + "2d_fwd": "forward", + "2d_bwdd": "bwd_data", + "2d_bwdw": "bwd_weight", + "fwd": "forward", + "bwdd": "bwd_data", + "bwdw": "bwd_weight", +} + +DIRECTION_MAP = {"forward": 0, "bwd_data": 1, "bwd_weight": 2} + + +def _resolve_variant(v: str) -> str: + return VARIANT_ALIASES.get(v, v) + + +# ============================================================================= +# GroupedConvDataType +# ============================================================================= + + +class GroupedConvDataType(Enum): + FP16 = "fp16" + BF16 = "bf16" + FP32 = "fp32" + FP8 = "fp8" + BF8 = "bf8" + INT8 = "int8" + + +# ============================================================================= +# GroupedConvKernelConfig +# ============================================================================= + + +@dataclass +class GroupedConvKernelConfig: + """Complete kernel configuration for grouped convolution. + + Captures all parameters needed to identify and run a specific kernel. + Mirrors the C++ GroupedConvSignature + GroupedConvAlgorithm. + """ + + # What: signature + variant: str = "forward" + ndim_spatial: int = 2 + dtype: str = "fp16" + layout: str = "nhwgc" + arch: str = "gfx942" + + # How: algorithm - tile shape + tile_m: int = 1 + tile_n: int = 128 + tile_k: int = 128 + + # How: wave config + wave_m: int = 2 + wave_n: int = 2 + wave_k: int = 1 + + # How: warp tile + warp_tile_m: int = 32 + warp_tile_n: int = 32 + warp_tile_k: int = 16 + + # How: pipeline traits + pipeline: str = "compv4" + epilogue: str = "cshuffle" + scheduler: str = "intrawave" + + # ConvConfigBase parity fields + vector_size_a: int = 4 + vector_size_b: int = 8 + vector_size_c: int = 8 + block_per_cu: int = 1 + num_wave_groups: int = 1 + num_groups_to_merge: int = 1 + + # Padding (enables arbitrary problem sizes) + pad_m: bool = True + pad_n: bool = True + pad_k: bool = True + + def __post_init__(self): + self.variant = _resolve_variant(self.variant) + if ( + self.variant in BACKWARD_VARIANTS + and self.pipeline not in BACKWARD_PIPELINES + ): + self.pipeline = "compv3" + + @property + def tile_str(self) -> str: + return f"{self.tile_m}x{self.tile_n}x{self.tile_k}" + + @property + def wave_str(self) -> str: + return f"{self.wave_m}x{self.wave_n}x{self.wave_k}" + + @property + def warp_str(self) -> str: + return f"{self.warp_tile_m}x{self.warp_tile_n}x{self.warp_tile_k}" + + @property + def vec_str(self) -> str: + return f"{self.vector_size_a}x{self.vector_size_b}x{self.vector_size_c}" + + @property + def name(self) -> str: + return ( + f"grouped_conv_{self.variant}_{self.dtype}_{self.ndim_spatial}d_" + f"{self.tile_str}_{self.pipeline}" + ) + + def to_dict(self) -> dict: + """Convert to legacy dict format for codegen compatibility.""" + return { + "tile_config": { + "tile_m": [self.tile_m], + "tile_n": [self.tile_n], + "tile_k": [self.tile_k], + "wave_m": [self.wave_m], + "wave_n": [self.wave_n], + "wave_k": [self.wave_k], + "warp_tile_m": [self.warp_tile_m], + "warp_tile_n": [self.warp_tile_n], + "warp_tile_k": [self.warp_tile_k], + }, + "trait_config": { + "pipeline": [self.pipeline], + "epilogue": [self.epilogue], + "scheduler": [self.scheduler], + "pad_m": [self.pad_m], + "pad_n": [self.pad_n], + "pad_k": [self.pad_k], + "vector_size_a": [self.vector_size_a], + "vector_size_b": [self.vector_size_b], + "vector_size_c": [self.vector_size_c], + "block_per_cu": [self.block_per_cu], + "num_wave_groups": [self.num_wave_groups], + "num_groups_to_merge": [self.num_groups_to_merge], + }, + "variant": self.variant, + "ndim_spatial": self.ndim_spatial, + "arch": self.arch, + "layout": self.layout, + "dtype": self.dtype, + } + + def to_json_obj(self) -> dict: + """Serializable dict for JSON export.""" + return { + "name": self.name, + "signature": { + "variant": self.variant, + "dtype": self.dtype, + "ndim_spatial": self.ndim_spatial, + "layout": self.layout, + }, + "algorithm": { + "tile_m": self.tile_m, + "tile_n": self.tile_n, + "tile_k": self.tile_k, + "wave": self.wave_str, + "warp": self.warp_str, + "pipeline": self.pipeline, + "epilogue": self.epilogue, + "scheduler": self.scheduler, + "vector_sizes": [ + self.vector_size_a, + self.vector_size_b, + self.vector_size_c, + ], + "block_per_cu": self.block_per_cu, + "num_wave_groups": self.num_wave_groups, + "num_groups_to_merge": self.num_groups_to_merge, + }, + "arch": self.arch, + } + + def print_config(self, indent: str = " "): + print(f"{indent}GroupedConvKernelConfig:") + print(f"{indent} Variant: {self.variant} {self.ndim_spatial}D") + print(f"{indent} Dtype: {self.dtype}") + print(f"{indent} Layout: {self.layout}") + print(f"{indent} Arch: {self.arch}") + print(f"{indent} Tile: {self.tile_str}") + print(f"{indent} Wave: {self.wave_str}") + print(f"{indent} Warp: {self.warp_str}") + print(f"{indent} Pipeline: {self.pipeline}/{self.scheduler}/{self.epilogue}") + print(f"{indent} VecSizes: {self.vec_str}") + print( + f"{indent} BlockCU: {self.block_per_cu} WaveGroups: {self.num_wave_groups} MergeGroups: {self.num_groups_to_merge}" + ) + + +# ============================================================================= +# GroupedConvProblem +# ============================================================================= + + +@dataclass +class GroupedConvProblem: + """Runtime convolution problem specification. + + Describes the actual sizes of a convolution to be computed. + Matches the old ConvProblem from conv_utils.py. + """ + + N: int = 1 + C: int = 64 + K: int = 128 + G: int = 1 + + Hi: int = 28 + Wi: int = 28 + Di: int = 1 + + Y: int = 3 + X: int = 3 + Z: int = 1 + + stride_h: int = 1 + stride_w: int = 1 + stride_d: int = 1 + + pad_h: int = 0 + pad_w: int = 0 + pad_d: int = 0 + + dilation_h: int = 1 + dilation_w: int = 1 + dilation_d: int = 1 + + direction: str = "forward" + split_k: int = 1 + + @property + def Ho(self) -> int: + eff_y = (self.Y - 1) * self.dilation_h + 1 + return (self.Hi + 2 * self.pad_h - eff_y) // self.stride_h + 1 + + @property + def Wo(self) -> int: + eff_x = (self.X - 1) * self.dilation_w + 1 + return (self.Wi + 2 * self.pad_w - eff_x) // self.stride_w + 1 + + @property + def Do(self) -> int: + eff_z = (self.Z - 1) * self.dilation_d + 1 + return (self.Di + 2 * self.pad_d - eff_z) // self.stride_d + 1 + + @property + def is_3d(self) -> bool: + return self.Di > 1 or self.Z > 1 or self.pad_d > 0 + + @property + def ndim_spatial(self) -> int: + return 3 if self.is_3d else 2 + + @property + def flops(self) -> float: + """Total FLOPs for this convolution (any direction, same count).""" + c_per_group = self.C // self.G + if self.is_3d: + return ( + 2.0 + * self.N + * self.K + * self.Do + * self.Ho + * self.Wo + * c_per_group + * self.Z + * self.Y + * self.X + ) + return 2.0 * self.N * self.K * self.Ho * self.Wo * c_per_group * self.Y * self.X + + @property + def gflops(self) -> float: + return self.flops / 1e9 + + def input_shape(self) -> tuple: + """NHWGC or NDHWGC layout.""" + c_per_g = self.C // self.G + if self.is_3d: + return (self.N, self.Di, self.Hi, self.Wi, self.G, c_per_g) + return (self.N, self.Hi, self.Wi, self.G, c_per_g) + + def weight_shape(self) -> tuple: + """GKYXC or GKZYXC layout.""" + c_per_g = self.C // self.G + k_per_g = self.K // self.G + if self.is_3d: + return (self.G, k_per_g, self.Z, self.Y, self.X, c_per_g) + return (self.G, k_per_g, self.Y, self.X, c_per_g) + + def output_shape(self) -> tuple: + """NHWGK or NDHWGK layout.""" + k_per_g = self.K // self.G + if self.is_3d: + return (self.N, self.Do, self.Ho, self.Wo, self.G, k_per_g) + return (self.N, self.Ho, self.Wo, self.G, k_per_g) + + def print_problem(self, indent: str = " "): + dim_str = "3D" if self.is_3d else "2D" + print(f"{indent}GroupedConvProblem ({dim_str} {self.direction}):") + print(f"{indent} Batch: N={self.N}, G={self.G}") + print(f"{indent} Channels: C={self.C}, K={self.K}") + if self.is_3d: + print(f"{indent} Input: Di={self.Di}, Hi={self.Hi}, Wi={self.Wi}") + print(f"{indent} Filter: Z={self.Z}, Y={self.Y}, X={self.X}") + print(f"{indent} Output: Do={self.Do}, Ho={self.Ho}, Wo={self.Wo}") + else: + print(f"{indent} Input: Hi={self.Hi}, Wi={self.Wi}") + print(f"{indent} Filter: Y={self.Y}, X={self.X}") + print(f"{indent} Output: Ho={self.Ho}, Wo={self.Wo}") + print(f"{indent} GFLOPs: {self.gflops:.2f}") + + +# ============================================================================= +# GroupedConvProblemC (ctypes struct matching C++) +# ============================================================================= + + +class GroupedConvProblemC(ctypes.Structure): + """C structure matching ConvProblemC in conv_ctypes_lib.cpp.""" + + _fields_ = [ + ("N", ctypes.c_int), + ("G", ctypes.c_int), + ("C", ctypes.c_int), + ("K", ctypes.c_int), + ("input_d", ctypes.c_int), + ("input_h", ctypes.c_int), + ("input_w", ctypes.c_int), + ("filter_z", ctypes.c_int), + ("filter_y", ctypes.c_int), + ("filter_x", ctypes.c_int), + ("stride_d", ctypes.c_int), + ("stride_h", ctypes.c_int), + ("stride_w", ctypes.c_int), + ("pad_d", ctypes.c_int), + ("pad_h", ctypes.c_int), + ("pad_w", ctypes.c_int), + ("dilation_d", ctypes.c_int), + ("dilation_h", ctypes.c_int), + ("dilation_w", ctypes.c_int), + ("direction", ctypes.c_int), + ("split_k", ctypes.c_int), + ] + + @classmethod + def from_problem(cls, p: GroupedConvProblem) -> "GroupedConvProblemC": + c = cls() + c.N, c.G, c.C, c.K = p.N, p.G, p.C, p.K + c.input_d, c.input_h, c.input_w = p.Di, p.Hi, p.Wi + c.filter_z, c.filter_y, c.filter_x = p.Z, p.Y, p.X + c.stride_d, c.stride_h, c.stride_w = p.stride_d, p.stride_h, p.stride_w + c.pad_d, c.pad_h, c.pad_w = p.pad_d, p.pad_h, p.pad_w + c.dilation_d, c.dilation_h, c.dilation_w = ( + p.dilation_d, + p.dilation_h, + p.dilation_w, + ) + c.direction = DIRECTION_MAP.get(p.direction, 0) + c.split_k = getattr(p, "split_k", 1) + return c + + +# ============================================================================= +# GroupedConvResult +# ============================================================================= + + +@dataclass +class GroupedConvResult: + """Result of GPU convolution execution.""" + + success: bool = False + time_ms: float = 0.0 + tflops: float = 0.0 + output: Optional[np.ndarray] = None + error: str = "" + + +# ============================================================================= +# GroupedConvDispatcherLib +# ============================================================================= + + +class GroupedConvDispatcherLib: + """Wrapper for the compiled convolution dispatcher library. + + Provides Python interface to the C API in conv_ctypes_lib.cpp. + """ + + SEARCH_PATHS = [ + "build/examples/libdispatcher_conv_lib.so", + "build/bindings/libdispatcher_conv_lib.so", + "build/lib/libdispatcher_conv_lib.so", + ] + + def __init__(self, lib: ctypes.CDLL, path: Path): + self._lib = lib + self._path = path + self._setup_functions() + + def _setup_functions(self): + self._lib.conv_dispatcher_init.argtypes = [] + self._lib.conv_dispatcher_init.restype = ctypes.c_int + self._lib.conv_dispatcher_cleanup.argtypes = [] + self._lib.conv_dispatcher_cleanup.restype = ctypes.c_int + self._lib.conv_dispatcher_version.argtypes = [] + self._lib.conv_dispatcher_version.restype = ctypes.c_char_p + self._lib.conv_dispatcher_has_kernels.argtypes = [] + self._lib.conv_dispatcher_has_kernels.restype = ctypes.c_int + self._lib.conv_dispatcher_has_bwd_data.argtypes = [] + self._lib.conv_dispatcher_has_bwd_data.restype = ctypes.c_int + self._lib.conv_dispatcher_has_bwd_weight.argtypes = [] + self._lib.conv_dispatcher_has_bwd_weight.restype = ctypes.c_int + self._lib.conv_dispatcher_get_kernel_count.argtypes = [] + self._lib.conv_dispatcher_get_kernel_count.restype = ctypes.c_int + self._lib.conv_dispatcher_get_kernel_name.argtypes = [ + ctypes.c_int, + ctypes.c_char_p, + ctypes.c_int, + ] + self._lib.conv_dispatcher_get_kernel_name.restype = ctypes.c_int + self._lib.conv_dispatcher_is_supported.argtypes = [ + ctypes.POINTER(GroupedConvProblemC), + ] + self._lib.conv_dispatcher_is_supported.restype = ctypes.c_int + self._lib.conv_dispatcher_run.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.POINTER(GroupedConvProblemC), + ctypes.c_void_p, + ] + self._lib.conv_dispatcher_run.restype = ctypes.c_float + + @classmethod + def find(cls) -> Optional["GroupedConvDispatcherLib"]: + """Search standard paths for the conv library.""" + root = Path(__file__).parent.parent + for rel in cls.SEARCH_PATHS: + path = root / rel + if path.exists(): + try: + lib = ctypes.CDLL(str(path)) + return cls(lib, path) + except OSError: + continue + return None + + @property + def path(self) -> Path: + return self._path + + def initialize(self): + self._lib.conv_dispatcher_init() + + def cleanup(self): + self._lib.conv_dispatcher_cleanup() + + def version(self) -> str: + return self._lib.conv_dispatcher_version().decode() + + def has_forward(self) -> bool: + return self._lib.conv_dispatcher_has_kernels() != 0 + + def has_bwd_data(self) -> bool: + return self._lib.conv_dispatcher_has_bwd_data() != 0 + + def has_bwd_weight(self) -> bool: + return self._lib.conv_dispatcher_has_bwd_weight() != 0 + + def kernel_count(self) -> int: + return self._lib.conv_dispatcher_get_kernel_count() + + def kernel_names(self) -> List[str]: + names = [] + for i in range(self.kernel_count()): + buf = ctypes.create_string_buffer(256) + if self._lib.conv_dispatcher_get_kernel_name(i, buf, 256) == 0: + names.append(buf.value.decode()) + return names + + def is_supported(self, problem: GroupedConvProblem) -> bool: + pc = GroupedConvProblemC.from_problem(problem) + return self._lib.conv_dispatcher_is_supported(ctypes.byref(pc)) != 0 + + def run( + self, a_ptr: int, b_ptr: int, c_ptr: int, problem: GroupedConvProblem + ) -> float: + """Run convolution. Returns time_ms (>0 success, <0 error).""" + pc = GroupedConvProblemC.from_problem(problem) + return self._lib.conv_dispatcher_run( + a_ptr, b_ptr, c_ptr, ctypes.byref(pc), None + ) + + +# ============================================================================= +# GpuGroupedConvRunner +# ============================================================================= + + +class GpuGroupedConvRunner: + """High-level GPU convolution runner. + + Handles library loading, HIP memory management, and kernel execution. + Follows the same pattern as the old GpuConvRunner from conv_utils.py. + + Usage: + runner = GpuGroupedConvRunner() + if runner.is_available(): + result = runner.run(input_np, weight_np, problem) + print(f"Time: {result.time_ms:.4f} ms, TFLOPS: {result.tflops:.2f}") + """ + + HIP_MEMCPY_H2D = 1 + HIP_MEMCPY_D2H = 2 + + def __init__(self, lib_path: Optional[str] = None): + self._dispatch_lib: Optional[GroupedConvDispatcherLib] = None + self._hip = None + self._initialized = False + + try: + if lib_path: + lib = ctypes.CDLL(lib_path) + self._dispatch_lib = GroupedConvDispatcherLib(lib, Path(lib_path)) + else: + self._dispatch_lib = GroupedConvDispatcherLib.find() + + if self._dispatch_lib is None: + return + + self._hip = ctypes.CDLL("libamdhip64.so") + self._hip.hipMalloc.argtypes = [ + ctypes.POINTER(ctypes.c_void_p), + ctypes.c_size_t, + ] + self._hip.hipMalloc.restype = ctypes.c_int + self._hip.hipFree.argtypes = [ctypes.c_void_p] + self._hip.hipFree.restype = ctypes.c_int + self._hip.hipMemcpy.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_size_t, + ctypes.c_int, + ] + self._hip.hipMemcpy.restype = ctypes.c_int + self._hip.hipDeviceSynchronize.argtypes = [] + self._hip.hipDeviceSynchronize.restype = ctypes.c_int + + self._dispatch_lib.initialize() + self._initialized = True + except Exception: + self._initialized = False + + def is_available(self) -> bool: + return self._initialized and self._dispatch_lib is not None + + @property + def library_path(self) -> Optional[str]: + if self._dispatch_lib: + return str(self._dispatch_lib.path) + return None + + @property + def lib(self) -> Optional[GroupedConvDispatcherLib]: + return self._dispatch_lib + + def run( + self, + input_np: np.ndarray, + weight_np: np.ndarray, + problem: GroupedConvProblem, + output_np: Optional[np.ndarray] = None, + ) -> GroupedConvResult: + """Run convolution on GPU. + + Args: + input_np: For forward: X (NHWGC). For bwd_data: dY. For bwd_weight: X. + weight_np: For forward: W (GKYXC). For bwd_data: W. For bwd_weight: dY. + problem: Problem specification. + output_np: Optional pre-allocated output buffer. + + Returns: + GroupedConvResult with success, time_ms, tflops, output. + """ + if not self.is_available(): + return GroupedConvResult(error="GPU not available") + + try: + # Determine output shape based on direction + d = problem.direction + if d == "bwd_data": + out_shape = problem.input_shape() + elif d == "bwd_weight": + out_shape = problem.weight_shape() + else: + out_shape = problem.output_shape() + + if output_np is None: + output_np = np.zeros(out_shape, dtype=input_np.dtype) + + output_size = output_np.nbytes + + # Allocate GPU memory + d_a, d_b, d_c = ctypes.c_void_p(), ctypes.c_void_p(), ctypes.c_void_p() + self._hip.hipMalloc(ctypes.byref(d_a), input_np.nbytes) + self._hip.hipMalloc(ctypes.byref(d_b), weight_np.nbytes) + self._hip.hipMalloc(ctypes.byref(d_c), output_size) + + # Host to device + self._hip.hipMemcpy( + d_a, input_np.ctypes.data, input_np.nbytes, self.HIP_MEMCPY_H2D + ) + self._hip.hipMemcpy( + d_b, weight_np.ctypes.data, weight_np.nbytes, self.HIP_MEMCPY_H2D + ) + self._hip.hipDeviceSynchronize() + + # Launch kernel + time_ms = self._dispatch_lib.run(d_a.value, d_b.value, d_c.value, problem) + self._hip.hipDeviceSynchronize() + + result = GroupedConvResult() + + if time_ms > 0: + # Device to host + self._hip.hipMemcpy( + output_np.ctypes.data, d_c, output_size, self.HIP_MEMCPY_D2H + ) + self._hip.hipDeviceSynchronize() + result.success = True + result.time_ms = time_ms + result.tflops = problem.flops / (time_ms * 1e9) + result.output = output_np + else: + result.error = ( + "unsupported" + if time_ms == -3.0 + else "no kernel" + if time_ms == -2.0 + else f"error (code {time_ms})" + ) + + # Free GPU memory + self._hip.hipFree(d_a) + self._hip.hipFree(d_b) + self._hip.hipFree(d_c) + + return result + + except Exception as e: + return GroupedConvResult(error=str(e)) + + def cleanup(self): + if self._dispatch_lib: + try: + self._dispatch_lib.cleanup() + except Exception: + pass + + +# ============================================================================= +# GroupedConvRegistry +# ============================================================================= + + +class GroupedConvRegistry: + """Collection of grouped conv kernel configs with JSON export/import.""" + + def __init__(self, name: str = "default"): + self.name = name + self._kernels: List[GroupedConvKernelConfig] = [] + + def add(self, config: GroupedConvKernelConfig): + self._kernels.append(config) + + @property + def kernels(self) -> List[GroupedConvKernelConfig]: + return list(self._kernels) + + def __len__(self) -> int: + return len(self._kernels) + + def select( + self, problem: "GroupedConvProblem", heuristic=None + ) -> Optional[GroupedConvKernelConfig]: + """Select the best kernel for a problem. + + Args: + problem: The convolution problem. + heuristic: Optional callable(problem) -> List[str] returning + ranked kernel name substrings. The registry tries + each in order; falls back to first matching kernel. + + Returns: + The best matching GroupedConvKernelConfig, or None. + """ + matching = [k for k in self._kernels if k.variant == problem.direction] + if not matching: + return None + + if heuristic is not None: + ranked = heuristic(problem) + for hint in ranked: + for k in matching: + if hint in k.name: + return k + + return matching[0] if matching else None + + def filter_by_variant(self, variant: str) -> "GroupedConvRegistry": + variant = _resolve_variant(variant) + reg = GroupedConvRegistry(f"{self.name}_{variant}") + for k in self._kernels: + if k.variant == variant: + reg.add(k) + return reg + + def filter_by_arch(self, arch: str) -> "GroupedConvRegistry": + reg = GroupedConvRegistry(f"{self.name}_{arch}") + for k in self._kernels: + if k.arch == arch: + reg.add(k) + return reg + + def to_json(self, indent: int = 2) -> str: + return json.dumps( + { + "name": self.name, + "kernels": [k.to_json_obj() for k in self._kernels], + }, + indent=indent, + ) + + @classmethod + def from_json(cls, json_str: str) -> "GroupedConvRegistry": + data = json.loads(json_str) + reg = cls(data.get("name", "imported")) + for kd in data.get("kernels", []): + sig = kd.get("signature", {}) + algo = kd.get("algorithm", {}) + wave = algo.get("wave", "2x2x1").split("x") + warp = algo.get("warp", "32x32x16").split("x") + vec = algo.get("vector_sizes", [4, 8, 8]) + reg.add( + GroupedConvKernelConfig( + variant=sig.get("variant", "forward"), + ndim_spatial=sig.get("ndim_spatial", 2), + dtype=sig.get("dtype", "fp16"), + layout=sig.get("layout", "nhwgc"), + arch=kd.get("arch", "gfx942"), + tile_m=algo.get("tile_m", 1), + tile_n=algo.get("tile_n", 128), + tile_k=algo.get("tile_k", 128), + wave_m=int(wave[0]), + wave_n=int(wave[1]), + wave_k=int(wave[2]), + warp_tile_m=int(warp[0]), + warp_tile_n=int(warp[1]), + warp_tile_k=int(warp[2]), + pipeline=algo.get("pipeline", "compv3"), + epilogue=algo.get("epilogue", "cshuffle"), + scheduler=algo.get("scheduler", "intrawave"), + vector_size_a=vec[0] if len(vec) > 0 else 4, + vector_size_b=vec[1] if len(vec) > 1 else 8, + vector_size_c=vec[2] if len(vec) > 2 else 8, + block_per_cu=algo.get("block_per_cu", 1), + num_wave_groups=algo.get("num_wave_groups", 1), + num_groups_to_merge=algo.get("num_groups_to_merge", 1), + ) + ) + return reg + + def build( + self, + verbose: bool = False, + max_workers: Optional[int] = None, + ) -> Dict[Tuple[str, int], "GpuGroupedConvRunner"]: + """Parallel JIT compile all kernels in this registry. + + Args: + verbose: Print progress during build. + max_workers: Max parallel codegen/compile processes (default: cpu_count capped at 8). + + Returns a dict mapping (variant, ndim_spatial) to a ready-to-use + GpuGroupedConvRunner. + """ + if not self._kernels: + return {} + + libs = setup_multiple_grouped_conv_dispatchers( + self._kernels, + verbose=verbose, + max_workers=max_workers, + ) + + runners: Dict[Tuple[str, int], GpuGroupedConvRunner] = {} + for cfg, lib in zip(self._kernels, libs): + if lib is None: + continue + key = (cfg.variant, cfg.ndim_spatial) + if key in runners: + continue + runner = GpuGroupedConvRunner(lib_path=str(lib.path)) + if runner.is_available(): + runners[key] = runner + return runners + + def print_registry(self, indent: str = " "): + print(f"{indent}Registry '{self.name}': {len(self)} kernels") + for i, k in enumerate(self._kernels): + print( + f"{indent} [{i}] {k.name} (valid={validate_grouped_conv_config(k.to_dict()).is_valid})" + ) + + +# ============================================================================= +# GroupedConvValidationResult +# ============================================================================= + + +@dataclass +class GroupedConvValidationResult(ValidationResultBase): + """Result of grouped conv kernel config validation.""" + + variant: str = "forward" + + def __init__( + self, + is_valid=True, + errors=None, + warnings=None, + suggested_fixes=None, + variant="forward", + ): + super().__init__( + is_valid=is_valid, + errors=errors or [], + warnings=warnings or [], + suggested_fixes=suggested_fixes or {}, + ) + self.variant = variant + + +# ============================================================================= +# Validation helpers (extracted from the original config extraction code) +# ============================================================================= + + +def _first(val): + if isinstance(val, list) and len(val) > 0: + return val[0] + return val + + +def _get_tile_config(config: dict) -> dict: + return config.get("tile_config") or {} + + +def _get_trait_config(config: dict) -> dict: + return config.get("trait_config") or {} + + +def _extract_wave_config(tile_config: dict) -> List[int]: + wm = tile_config.get("wave_m") or tile_config.get("warp_m") + wn = tile_config.get("wave_n") or tile_config.get("warp_n") + wk = tile_config.get("wave_k") or tile_config.get("warp_k") + if wm is not None and wn is not None and wk is not None: + return [_first(wm), _first(wn), _first(wk)] + return [2, 2, 1] + + +def _extract_warp_tile_config(tile_config: dict) -> List[int]: + wtm = tile_config.get("warp_tile_m") or tile_config.get("warp_m") + wtn = tile_config.get("warp_tile_n") or tile_config.get("warp_n") + wtk = tile_config.get("warp_tile_k") or tile_config.get("warp_k") + if wtm is not None and wtn is not None and wtk is not None: + return [_first(wtm), _first(wtn), _first(wtk)] + return [32, 32, 16] + + +def _extract_trait_values(trait_config: dict) -> Tuple[str, str, str]: + p = _first(trait_config.get("pipeline", "compv4")) + e = _first(trait_config.get("epilogue", "cshuffle")) + s = _first(trait_config.get("scheduler", "intrawave")) + if isinstance(p, list): + p = p[0] if p else "compv4" + if isinstance(e, list): + e = e[0] if e else "cshuffle" + if isinstance(s, list): + s = s[0] if s else "intrawave" + return (str(p), str(e), str(s)) + + +# ============================================================================= +# validate_grouped_conv_config / auto_correct_grouped_conv_config +# ============================================================================= + + +def validate_grouped_conv_config(config: dict) -> GroupedConvValidationResult: + """Validate a grouped conv kernel config dict. + + Accepts either a raw dict (legacy) or GroupedConvKernelConfig.to_dict() output. + """ + errors: List[str] = [] + warnings: List[str] = [] + suggested_fixes: Dict[str, Any] = {} + + required = ( + "tile_config", + "trait_config", + "variant", + "ndim_spatial", + "arch", + "layout", + ) + for key in required: + if key not in config: + errors.append(f"Missing required key: {key}") + if errors: + return GroupedConvValidationResult( + is_valid=False, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + variant=config.get("variant", "forward"), + ) + + tile_config = _get_tile_config(config) + trait_config = _get_trait_config(config) + variant = _first(config.get("variant", "forward")) + if isinstance(variant, list): + variant = variant[0] if variant else "forward" + variant = _resolve_variant(str(variant)) + + ndim_spatial = config.get("ndim_spatial") + arch = config.get("arch", "gfx942") + dtype = config.get("dtype", "fp16") + + if variant not in VALID_VARIANTS: + errors.append(f"Invalid variant: {variant}. Valid: {', '.join(VALID_VARIANTS)}") + suggested_fixes["variant"] = "forward" + + if ndim_spatial is not None: + ndim = ndim_spatial + if isinstance(ndim, list): + ndim = ndim[0] if ndim else 2 + if ndim not in VALID_NDIM_SPATIAL: + errors.append( + f"Invalid ndim_spatial: {ndim}. Valid: {', '.join(map(str, VALID_NDIM_SPATIAL))}" + ) + suggested_fixes["ndim_spatial"] = 2 + + pipeline, epilogue, scheduler = _extract_trait_values(trait_config) + if variant in BACKWARD_VARIANTS and pipeline not in BACKWARD_PIPELINES: + errors.append( + f"Backward variant '{variant}' requires pipeline compv3 or mem, got {pipeline}" + ) + suggested_fixes["pipeline"] = "compv3" + + ok, msg = validate_trait_combo(pipeline, epilogue, scheduler) + if not ok: + errors.append(msg) + suggested_fixes["scheduler"] = "intrawave" + + wave_cfg = _extract_wave_config(tile_config) + ok, msg = validate_wave_config(wave_cfg, arch) + if not ok: + errors.append(msg) + arch_data = get_arch_filter_data() + valid_waves = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + if valid_waves: + suggested_fixes["wave_m"] = valid_waves[0][0] + suggested_fixes["wave_n"] = valid_waves[0][1] + suggested_fixes["wave_k"] = valid_waves[0][2] + + warp_cfg = _extract_warp_tile_config(tile_config) + ok, msg = validate_warp_tile_config(warp_cfg, arch, dtype) + if not ok: + errors.append(msg) + arch_data = get_arch_filter_data() + acc = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc}" + valid_tiles = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + if valid_tiles: + suggested_fixes["warp_tile_m"] = valid_tiles[0][0] + suggested_fixes["warp_tile_n"] = valid_tiles[0][1] + suggested_fixes["warp_tile_k"] = valid_tiles[0][2] + + arch_data = get_arch_filter_data() + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}. Supported: {', '.join(arch_data['supported_archs'])}" + ) + + return GroupedConvValidationResult( + is_valid=len(errors) == 0, + errors=errors, + warnings=warnings, + suggested_fixes=suggested_fixes, + variant=variant, + ) + + +def auto_correct_grouped_conv_config( + config: dict, +) -> Tuple[dict, GroupedConvValidationResult]: + """Auto-correct invalid grouped conv config. Returns (corrected, result).""" + result = validate_grouped_conv_config(config) + corrected = copy.deepcopy(config) + + if result.is_valid: + return corrected, result + + tile_config = corrected.setdefault("tile_config", {}) + trait_config = corrected.setdefault("trait_config", {}) + + wave_cfg = _extract_wave_config(tile_config) + arch = config.get("arch", "gfx942") + fixed_wave = auto_correct_wave(wave_cfg, arch) + tile_config["wave_m"] = fixed_wave[0] + tile_config["wave_n"] = fixed_wave[1] + tile_config["wave_k"] = fixed_wave[2] + + pipeline, epilogue, scheduler = _extract_trait_values(trait_config) + fixed_pipeline, fixed_scheduler = auto_correct_trait(pipeline, scheduler) + trait_config["pipeline"] = fixed_pipeline + trait_config["scheduler"] = fixed_scheduler + + variant = _first(config.get("variant", "forward")) + if isinstance(variant, list): + variant = variant[0] if variant else "forward" + variant = _resolve_variant(str(variant)) + if variant in BACKWARD_VARIANTS and fixed_pipeline not in BACKWARD_PIPELINES: + trait_config["pipeline"] = "compv3" + + if "warp_tile_m" in result.suggested_fixes: + tile_config["warp_tile_m"] = result.suggested_fixes["warp_tile_m"] + tile_config["warp_tile_n"] = result.suggested_fixes["warp_tile_n"] + tile_config["warp_tile_k"] = result.suggested_fixes["warp_tile_k"] + + result = validate_grouped_conv_config(corrected) + return corrected, result + + +def _run_hipcc_subprocess(args: dict) -> Tuple[bool, Optional[Path], str]: + """Run one hipcc compile+link job in a subprocess worker.""" + import subprocess + from pathlib import Path + + compile_cmd = args["compile_cmd"] + link_cmd = args["link_cmd"] + lib_path = Path(args["lib_path"]) + + try: + res_c = subprocess.run(compile_cmd, capture_output=True, text=True, timeout=300) + if res_c.returncode != 0: + return False, None, f"Compile failed: {res_c.stderr[:400]}" + + res_l = subprocess.run(link_cmd, capture_output=True, text=True, timeout=300) + if res_l.returncode != 0: + return False, None, f"Link failed: {res_l.stderr[:400]}" + + return True, lib_path, "" + except subprocess.TimeoutExpired: + return False, None, "Timeout" + except Exception as e: + return False, None, f"Error: {e}" + + +def _run_conv_codegen_subprocess(args: dict) -> Tuple[bool, Optional[str], str]: + """Run grouped-conv codegen once and return generated kernel header path.""" + import subprocess + from pathlib import Path + + out_dir = Path(args["output_dir"]) + out_dir.mkdir(parents=True, exist_ok=True) + + # Remove stale kernels so header discovery is exact for this invocation. + for stale in out_dir.glob("grouped_conv_*.hpp"): + stale.unlink(missing_ok=True) + for stale in out_dir.glob("include_all_grouped_conv_*.hpp"): + stale.unlink(missing_ok=True) + + try: + res = subprocess.run(args["cmd"], capture_output=True, text=True, timeout=300) + if res.returncode != 0: + err = (res.stderr or res.stdout or "").strip()[:500] + return False, None, f"Codegen failed: {err}" + + generated = sorted( + out_dir.glob("grouped_conv_*.hpp"), + key=lambda p: p.stat().st_mtime, + reverse=True, + ) + if not generated: + return False, None, "Codegen produced no grouped_conv_*.hpp header" + + return True, str(generated[0]), "" + except subprocess.TimeoutExpired: + return False, None, "Codegen timed out" + except Exception as e: + return False, None, f"Codegen error: {e}" + + +def _config_key(c: GroupedConvKernelConfig) -> Tuple[Any, ...]: + return ( + c.variant, + c.ndim_spatial, + c.dtype, + c.layout, + c.arch, + c.tile_m, + c.tile_n, + c.tile_k, + c.wave_m, + c.wave_n, + c.wave_k, + c.warp_tile_m, + c.warp_tile_n, + c.warp_tile_k, + c.pipeline, + c.epilogue, + c.scheduler, + ) + + +def _parse_triplet(value: str) -> Tuple[int, int, int]: + parts = value.split("x") + if len(parts) != 3: + raise ValueError(f"Invalid triplet: {value}") + return int(parts[0]), int(parts[1]), int(parts[2]) + + +def _list_arch_valid_grouped_conv_configs( + codegen_script: Path, + arch: str, + dtype: str, + variant: str, + ndim_spatial: int, +) -> List[GroupedConvKernelConfig]: + """Query codegen defaults for this (arch, dtype, variant, ndim) tuple.""" + import re + import sys + + cmd = [ + sys.executable, + str(codegen_script), + "--list-configs", + "--arch", + arch, + "--datatype", + dtype, + "--variant", + variant, + "--ndim", + str(ndim_spatial), + ] + res = subprocess.run(cmd, capture_output=True, text=True, timeout=180) + if res.returncode != 0: + return [] + + # Example: + # grouped_conv_fwd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x32_2x2x1_32x32x16 + name_re = re.compile( + r"^grouped_conv_(fwd|bwd_data|bwd_weight|bwdd|bwdw)_([a-z0-9]+)_([a-z0-9]+)_([123])d_" + r"([a-z0-9]+)_([a-z0-9]+)_([a-z0-9]+)_" + r"([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)_([0-9]+x[0-9]+x[0-9]+)" + r"(?:_.*)?$" + ) + short_to_variant = { + "fwd": "forward", + "bwd_data": "bwd_data", + "bwd_weight": "bwd_weight", + "bwdd": "bwd_data", + "bwdw": "bwd_weight", + } + + out: List[GroupedConvKernelConfig] = [] + seen = set() + for raw in res.stdout.splitlines(): + line = raw.strip() + if not line.startswith("- grouped_conv_"): + continue + name = line[2:].strip() + m = name_re.match(name) + if not m: + continue + + v_short, dt, layout, ndim, pipe, epi, sched, tile_s, wave_s, warp_s = m.groups() + tm, tn, tk = _parse_triplet(tile_s) + wm, wn, wk = _parse_triplet(wave_s) + wtm, wtn, wtk = _parse_triplet(warp_s) + + cfg = GroupedConvKernelConfig( + variant=short_to_variant[v_short], + ndim_spatial=int(ndim), + dtype=dt, + layout=layout, + arch=arch, + tile_m=tm, + tile_n=tn, + tile_k=tk, + wave_m=wm, + wave_n=wn, + wave_k=wk, + warp_tile_m=wtm, + warp_tile_n=wtn, + warp_tile_k=wtk, + pipeline=pipe, + epilogue=epi, + scheduler=sched, + ) + key = _config_key(cfg) + if key not in seen: + out.append(cfg) + seen.add(key) + + return out + + +def _select_best_arch_valid_conv_config( + requested: GroupedConvKernelConfig, + candidates: List[GroupedConvKernelConfig], +) -> GroupedConvKernelConfig: + """Pick nearest arch-valid config while preferring trait exact matches.""" + + def score(c: GroupedConvKernelConfig) -> Tuple[int, int, int, int, int, int]: + tile_delta = ( + abs(c.tile_m - requested.tile_m) + + abs(c.tile_n - requested.tile_n) + + abs(c.tile_k - requested.tile_k) + ) + wave_delta = ( + abs(c.wave_m - requested.wave_m) + + abs(c.wave_n - requested.wave_n) + + abs(c.wave_k - requested.wave_k) + ) + warp_tile_delta = ( + abs(c.warp_tile_m - requested.warp_tile_m) + + abs(c.warp_tile_n - requested.warp_tile_n) + + abs(c.warp_tile_k - requested.warp_tile_k) + ) + return ( + 0 if c.pipeline == requested.pipeline else 1, + 0 if c.scheduler == requested.scheduler else 1, + 0 if c.epilogue == requested.epilogue else 1, + tile_delta, + wave_delta, + warp_tile_delta, + ) + + best = min(candidates, key=score) + selected = copy.deepcopy(best) + selected.arch = requested.arch + return selected + + +def _write_single_conv_dispatch_header( + config: GroupedConvKernelConfig, + kernel_header: Path, + dispatch_header: Path, +) -> None: + """Create a tiny dispatch header consumed by conv_ctypes_lib.cpp.""" + macros: List[str] = [] + aliases: List[str] = [] + + if config.variant == "forward": + kernel_name_symbol = "CONV_FWD_KERNEL_NAME" + if config.ndim_spatial == 3: + macros.append("#define CONV_FWD_3D_AVAILABLE 1") + aliases.append("using ConvFwd3dLauncher = SelectedConvKernelLauncher;") + else: + macros.append("#define CONV_FWD_2D_AVAILABLE 1") + elif config.variant == "bwd_data": + kernel_name_symbol = "CONV_BWD_DATA_KERNEL_NAME" + if config.ndim_spatial == 3: + macros.append("#define CONV_BWD_DATA_3D_AVAILABLE 1") + aliases.append("using ConvBwdData3dLauncher = SelectedConvBwdDataLauncher;") + else: + macros.append("#define CONV_BWD_DATA_2D_AVAILABLE 1") + else: + kernel_name_symbol = "CONV_BWD_WEIGHT_KERNEL_NAME" + if config.ndim_spatial == 3: + macros.append("#define CONV_BWD_WEIGHT_3D_AVAILABLE 1") + aliases.append( + "using ConvBwdWeight3dLauncher = SelectedConvBwdWeightLauncher;" + ) + else: + macros.append("#define CONV_BWD_WEIGHT_2D_AVAILABLE 1") + + content = ( + "// Auto-generated single-kernel dispatch header for Python JIT\n" + "#pragma once\n\n" + f'#include "{kernel_header.name}"\n\n' + + "\n".join(macros) + + "\n\n" + + "\n".join(aliases) + + "\n\n" + + f"static const char* CONV_KERNEL_NAMES[] = {{{kernel_name_symbol}}};\n" + + "static constexpr int CONV_KERNEL_COUNT = 1;\n" + ) + dispatch_header.write_text(content) + + +class GroupedConvCodegenRunner: + """Generate and compile grouped-conv JIT libraries in parallel.""" + + def __init__(self, max_workers: Optional[int] = None): + import multiprocessing + + self.max_workers = max_workers or min(multiprocessing.cpu_count(), 8) + self.root = Path(__file__).parent.parent + self.build_dir = self.root / "build" + self.codegen_script = self.root / "codegen" / "unified_grouped_conv_codegen.py" + + def generate_and_compile_parallel( + self, + configs: List[GroupedConvKernelConfig], + verbose: bool = True, + ) -> List[Optional[Path]]: + import sys + from concurrent.futures import ProcessPoolExecutor, as_completed + + if not configs: + return [] + + if not self.build_dir.exists(): + self.build_dir.mkdir(parents=True, exist_ok=True) + + ctypes_source = self.root / "bindings" / "ctypes" / "conv_ctypes_lib.cpp" + static_lib = self.build_dir / "libck_tile_dispatcher.a" + jit_root = self.build_dir / "generated_kernels" / "python_jit" + jit_root.mkdir(parents=True, exist_ok=True) + (self.build_dir / "examples").mkdir(parents=True, exist_ok=True) + + if not self.codegen_script.exists(): + if verbose: + print(f"Codegen script missing: {self.codegen_script}") + return [None] * len(configs) + if not ctypes_source.exists() or not static_lib.exists(): + if verbose: + print("Missing conv ctypes source or static dispatcher library") + return [None] * len(configs) + + if verbose: + print( + f"Generating {len(configs)} grouped-conv kernels in parallel " + f"(workers={self.max_workers})..." + ) + + gen_jobs: List[Dict[str, Any]] = [] + job_dirs: List[Path] = [] + for i, c in enumerate(configs): + cfg_dir = jit_root / f"cfg_{i}" + cfg_dir.mkdir(parents=True, exist_ok=True) + job_dirs.append(cfg_dir) + + cmd = [ + sys.executable, + str(self.codegen_script), + "--output", + str(cfg_dir), + "--datatype", + c.dtype, + "--variant", + c.variant, + "--ndim", + str(c.ndim_spatial), + "--arch", + c.arch, + "--tile-m", + str(c.tile_m), + "--tile-n", + str(c.tile_n), + "--tile-k", + str(c.tile_k), + "--warp-m", + str(c.wave_m), + "--warp-n", + str(c.wave_n), + "--warp-k", + str(c.wave_k), + "--warp-tile-m", + str(c.warp_tile_m), + "--warp-tile-n", + str(c.warp_tile_n), + "--warp-tile-k", + str(c.warp_tile_k), + "--pipeline", + c.pipeline, + "--scheduler", + c.scheduler, + "--epilogue", + c.epilogue, + ] + gen_jobs.append({"cmd": cmd, "output_dir": str(cfg_dir)}) + + generated_headers: List[Optional[Path]] = [None] * len(configs) + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_conv_codegen_subprocess, job): idx + for idx, job in enumerate(gen_jobs) + } + for future in as_completed(futures): + idx = futures[future] + ok, header_path, err = future.result() + if ok and header_path: + generated_headers[idx] = Path(header_path) + if verbose: + print(f" OK [{idx}] codegen: {Path(header_path).name}") + else: + if verbose: + print(f" FAIL [{idx}] codegen: {err}") + + if verbose: + compile_count = sum(1 for h in generated_headers if h is not None) + print( + f"Compiling {compile_count} grouped-conv libraries in parallel " + f"(workers={self.max_workers})..." + ) + + compile_jobs: List[Dict[str, Any]] = [] + compile_to_input_index: Dict[int, int] = {} + for i, c in enumerate(configs): + hdr_path = generated_headers[i] + if hdr_path is None: + continue + + cfg_dir = job_dirs[i] + dispatch_header = cfg_dir / "conv_python_dispatch.hpp" + _write_single_conv_dispatch_header(c, hdr_path, dispatch_header) + + lib_name = ( + f"libdispatcher_conv_{c.variant}_{c.ndim_spatial}d_{c.dtype}_" + f"{c.tile_str}_{c.wave_str}_{c.warp_str}_{c.pipeline}_{c.scheduler}.so" + ) + lib_path = self.build_dir / "examples" / lib_name + obj_file = lib_path.with_suffix(".o") + + compile_cmd = [ + "/opt/rocm/bin/hipcc", + "-c", + "-fPIC", + "-O3", + f"-I{self.root / 'include'}", + f"-I{self.root.parent / 'include'}", + f"-I{self.root.parent}", + f"-I{cfg_dir}", + "-DCK_TILE_SINGLE_KERNEL_INCLUDE", + f"-include{dispatch_header}", + "-D__HIP_PLATFORM_AMD__", + f"--offload-arch={c.arch}", + f'-DGFX_ARCH="{c.arch}"', + "-mllvm", + "-enable-noalias-to-md-conversion=0", + "-Wno-undefined-func-template", + "-Wno-float-equal", + str(ctypes_source), + "-o", + str(obj_file), + ] + link_cmd = [ + "/opt/rocm/bin/hipcc", + "-shared", + "-fPIC", + f"--offload-arch={c.arch}", + "--hip-link", + str(obj_file), + str(static_lib), + "-o", + str(lib_path), + ] + + compile_to_input_index[len(compile_jobs)] = i + compile_jobs.append( + { + "compile_cmd": compile_cmd, + "link_cmd": link_cmd, + "lib_path": str(lib_path), + "config_name": c.name, + } + ) + + results_map: Dict[int, Optional[Path]] = {i: None for i in range(len(configs))} + with ProcessPoolExecutor(max_workers=self.max_workers) as executor: + futures = { + executor.submit(_run_hipcc_subprocess, job): j + for j, job in enumerate(compile_jobs) + } + for future in as_completed(futures): + job_idx = futures[future] + idx = compile_to_input_index[job_idx] + success, lib_path, err = future.result() + if success and lib_path: + results_map[idx] = Path(lib_path) + if verbose: + status = "OK" if success else f"FAIL ({err})" + name = ( + Path(lib_path).name + if success and lib_path + else compile_jobs[job_idx]["config_name"] + ) + print(f" {status} {name}") + + return [results_map.get(i) for i in range(len(configs))] + + +# ============================================================================= +# Convenience functions +# ============================================================================= + + +def get_grouped_conv_default_config( + variant: str = "forward", + ndim_spatial: int = 2, + arch: str = "gfx942", + dtype: str = "fp16", +) -> GroupedConvKernelConfig: + """Return a valid default GroupedConvKernelConfig.""" + return GroupedConvKernelConfig( + variant=variant, + ndim_spatial=ndim_spatial, + arch=arch, + dtype=dtype, + ) + + +def format_grouped_conv_summary(config) -> str: + """Format a config (dict or GroupedConvKernelConfig) into a human-readable string.""" + if isinstance(config, GroupedConvKernelConfig): + lines = [ + f"Grouped Conv Config: {config.variant} {config.ndim_spatial}D", + f" Arch: {config.arch}", + f" Layout: {config.layout}", + f" Dtype: {config.dtype}", + f" Tile: {config.tile_str}", + f" Wave: {config.wave_str}", + f" Warp: {config.warp_str}", + f" Traits: pipeline={config.pipeline} epilogue={config.epilogue} scheduler={config.scheduler}", + ] + return "\n".join(lines) + + # Legacy dict support + tile_config = _get_tile_config(config) if isinstance(config, dict) else {} + trait_config = _get_trait_config(config) if isinstance(config, dict) else {} + variant = config.get("variant", "?") if isinstance(config, dict) else "?" + ndim = config.get("ndim_spatial", "?") if isinstance(config, dict) else "?" + arch = config.get("arch", "?") if isinstance(config, dict) else "?" + layout = config.get("layout", "?") if isinstance(config, dict) else "?" + dtype = config.get("dtype", "fp16") if isinstance(config, dict) else "fp16" + + lines = [f"Grouped Conv Config: {variant} {ndim}D"] + lines.append(f" Arch: {arch}") + lines.append(f" Layout: {layout}") + lines.append(f" Dtype: {dtype}") + + if tile_config: + wave = _extract_wave_config(tile_config) + warp = _extract_warp_tile_config(tile_config) + lines.append( + f" Tile: M={_first(tile_config.get('tile_m', 1))} N={_first(tile_config.get('tile_n', 128))} K={_first(tile_config.get('tile_k', 128))}" + ) + lines.append(f" Wave: {wave[0]}x{wave[1]}x{wave[2]}") + lines.append(f" Warp: {warp[0]}x{warp[1]}x{warp[2]}") + + if trait_config: + pipeline = _first(trait_config.get("pipeline", "?")) + epilogue = _first(trait_config.get("epilogue", "?")) + scheduler = _first(trait_config.get("scheduler", "?")) + lines.append( + f" Traits: pipeline={pipeline} epilogue={epilogue} scheduler={scheduler}" + ) + + return "\n".join(lines) if lines else "(empty config)" + + +def setup_multiple_grouped_conv_dispatchers( + configs: List[GroupedConvKernelConfig], + verbose: bool = True, + max_workers: Optional[int] = None, +) -> List[Optional[GroupedConvDispatcherLib]]: + """ + Setup multiple grouped-conv dispatchers in parallel. + + This keeps architecture filtering strict: + 1. Validate + auto-correct each requested config + 2. Query codegen's arch-valid config set for each (arch, dtype, variant, ndim) + 3. Map each request to nearest valid config + 4. Parallel codegen + parallel compile + """ + if not configs: + return [] + + codegen_script = ( + Path(__file__).parent.parent / "codegen" / "unified_grouped_conv_codegen.py" + ) + arch_valid_cache: Dict[ + Tuple[str, str, str, int], List[GroupedConvKernelConfig] + ] = {} + + selected_configs: List[Optional[GroupedConvKernelConfig]] = [] + for i, original in enumerate(configs): + c = copy.deepcopy(original) + + val = validate_grouped_conv_config(c.to_dict()) + if not val.is_valid: + corrected, corrected_result = auto_correct_grouped_conv_config(c.to_dict()) + if not corrected_result.is_valid: + if verbose: + print(f" FAIL [{i}] config remains invalid after auto-correct") + selected_configs.append(None) + continue + + tile_cfg = corrected.get("tile_config", {}) + trait_cfg = corrected.get("trait_config", {}) + c.variant = _resolve_variant( + str(_first(corrected.get("variant", c.variant))) + ) + c.ndim_spatial = int(_first(corrected.get("ndim_spatial", c.ndim_spatial))) + c.arch = str(corrected.get("arch", c.arch)) + c.layout = str(corrected.get("layout", c.layout)) + c.dtype = str(corrected.get("dtype", c.dtype)) + c.tile_m = int(_first(tile_cfg.get("tile_m", c.tile_m))) + c.tile_n = int(_first(tile_cfg.get("tile_n", c.tile_n))) + c.tile_k = int(_first(tile_cfg.get("tile_k", c.tile_k))) + c.wave_m = int(_first(tile_cfg.get("wave_m", c.wave_m))) + c.wave_n = int(_first(tile_cfg.get("wave_n", c.wave_n))) + c.wave_k = int(_first(tile_cfg.get("wave_k", c.wave_k))) + c.warp_tile_m = int(_first(tile_cfg.get("warp_tile_m", c.warp_tile_m))) + c.warp_tile_n = int(_first(tile_cfg.get("warp_tile_n", c.warp_tile_n))) + c.warp_tile_k = int(_first(tile_cfg.get("warp_tile_k", c.warp_tile_k))) + c.pipeline = str(_first(trait_cfg.get("pipeline", c.pipeline))) + c.scheduler = str(_first(trait_cfg.get("scheduler", c.scheduler))) + c.epilogue = str(_first(trait_cfg.get("epilogue", c.epilogue))) + + cache_key = (c.arch, c.dtype, c.variant, c.ndim_spatial) + if cache_key not in arch_valid_cache: + arch_valid_cache[cache_key] = _list_arch_valid_grouped_conv_configs( + codegen_script=codegen_script, + arch=c.arch, + dtype=c.dtype, + variant=c.variant, + ndim_spatial=c.ndim_spatial, + ) + if verbose and not arch_valid_cache[cache_key]: + print( + f" FAIL [{i}] no arch-valid configs listed for " + f"{c.arch}/{c.dtype}/{c.variant}/{c.ndim_spatial}d" + ) + + candidates = arch_valid_cache[cache_key] + if not candidates: + selected_configs.append(None) + continue + + selected = _select_best_arch_valid_conv_config(c, candidates) + if verbose and _config_key(selected) != _config_key(c): + print( + f" INFO [{i}] mapped to arch-valid config: " + f"{selected.tile_str} {selected.wave_str} {selected.warp_str} " + f"{selected.pipeline}/{selected.scheduler}/{selected.epilogue}" + ) + selected_configs.append(selected) + + unique_configs: List[GroupedConvKernelConfig] = [] + unique_index_by_key: Dict[Tuple[Any, ...], int] = {} + input_to_unique: List[Optional[int]] = [] + for cfg in selected_configs: + if cfg is None: + input_to_unique.append(None) + continue + key = _config_key(cfg) + if key not in unique_index_by_key: + unique_index_by_key[key] = len(unique_configs) + unique_configs.append(cfg) + input_to_unique.append(unique_index_by_key[key]) + + runner = GroupedConvCodegenRunner(max_workers=max_workers) + unique_lib_paths = runner.generate_and_compile_parallel( + unique_configs, verbose=verbose + ) + + libs: List[Optional[GroupedConvDispatcherLib]] = [] + loaded_cache: Dict[int, Optional[GroupedConvDispatcherLib]] = {} + for input_idx, unique_idx in enumerate(input_to_unique): + if unique_idx is None: + libs.append(None) + continue + + if unique_idx in loaded_cache: + libs.append(loaded_cache[unique_idx]) + continue + + path = ( + unique_lib_paths[unique_idx] if unique_idx < len(unique_lib_paths) else None + ) + disp: Optional[GroupedConvDispatcherLib] = None + if path and path.exists(): + try: + lib = ctypes.CDLL(str(path)) + disp = GroupedConvDispatcherLib(lib, path) + disp.initialize() + except Exception as e: + if verbose: + print(f" FAIL [{input_idx}] failed to load {path}: {e}") + loaded_cache[unique_idx] = disp + libs.append(disp) + + return libs + + +def detect_gpu_arch() -> str: + """Detect GPU architecture using rocminfo.""" + try: + out = subprocess.check_output( + ["rocminfo"], stderr=subprocess.DEVNULL, text=True + ) + for line in out.split("\n"): + if "gfx" in line.lower() and "name:" in line.lower(): + for part in line.split(): + if part.startswith("gfx"): + return part + except Exception: + pass + return "gfx942" diff --git a/dispatcher/scripts/compile_gemm_examples.py b/dispatcher/scripts/compile_gemm_examples.py index b19c18a13a..98ba18ab51 100644 --- a/dispatcher/scripts/compile_gemm_examples.py +++ b/dispatcher/scripts/compile_gemm_examples.py @@ -94,17 +94,17 @@ def find_hipcc() -> str: def extract_conv_kernel_declarations(source_file: Path) -> list: - """Extract CONVOLUTION kernel declarations from C++ source file. + """Extract GROUPED CONVOLUTION kernel declarations from C++ source file. - Supports DECL_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. + Supports DECL_GROUPED_CONV_KERNEL_SET macro with ConvSig/ConvAlgo pattern. Extracts all parameters: dtype, layout, conv_type, dims, tile, wave, warp, pipeline, scheduler. """ content = source_file.read_text() declarations = [] seen = set() - # Pattern: DECL_CONV_KERNEL_SET(name, .add(...).add(...)) - set_pattern = r"DECL_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" + # Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...)) + set_pattern = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*,([^;]+)\)" for match in re.finditer(set_pattern, content, re.DOTALL): set_name = match.group(1) @@ -396,24 +396,23 @@ def expand_conv_declaration_with_arch_filter(decl: dict, arch: str = "gfx942") - def generate_conv_kernels(declarations: list, gpu_target: str = "gfx942") -> int: - """Generate convolution kernels using unified_conv_codegen.""" + """Generate grouped convolution kernels using unified_grouped_conv_codegen.""" kernel_dir = get_generated_kernels_dir() kernel_dir.mkdir(parents=True, exist_ok=True) - # Import conv codegen codegen_dir = get_dispatcher_root() / "codegen" sys.path.insert(0, str(codegen_dir)) try: - from unified_conv_codegen import ( - UnifiedConvCodegen, - ConvKernelConfig, - ConvVariant, + from unified_grouped_conv_codegen import ( + UnifiedGroupedConvCodegen as UnifiedConvCodegen, + GroupedConvKernelConfig as ConvKernelConfig, + GroupedConvVariant as ConvVariant, TileConfig, - TraitConfig, + GroupedConvTraitConfig as TraitConfig, ) except ImportError as e: - print_error(f" Failed to import conv codegen: {e}") + print_error(f" Failed to import grouped conv codegen: {e}") return 0 codegen = UnifiedConvCodegen(kernel_dir) @@ -1564,9 +1563,9 @@ def build_exact_conv_kernel_filename(decl: dict) -> str: if conv_type == "forward": type_prefix = "fwd" elif conv_type == "bwd_data": - type_prefix = "bwdd" + type_prefix = "bwd_data" elif conv_type == "bwd_weight": - type_prefix = "bwdw" + type_prefix = "bwd_weight" else: type_prefix = conv_type @@ -1601,9 +1600,9 @@ def generate_specific_conv_kernel(decl: dict, gpu_target: str = "gfx942") -> boo else: variant = "forward" - # Use unified_conv_codegen + # Use unified_grouped_conv_codegen codegen_dir = get_dispatcher_root() / "codegen" - codegen_script = codegen_dir / "unified_conv_codegen.py" + codegen_script = codegen_dir / "unified_grouped_conv_codegen.py" output_dir = get_generated_kernels_dir() cmd = [ @@ -1661,9 +1660,9 @@ def find_conv_kernel_header(decl: dict, gpu_target: str = "gfx942") -> Path: if conv_type == "forward": type_prefix = "fwd" elif conv_type == "bwd_data": - type_prefix = "bwdd" + type_prefix = "bwd_data" elif conv_type == "bwd_weight": - type_prefix = "bwdw" + type_prefix = "bwd_weight" else: type_prefix = conv_type @@ -1865,7 +1864,9 @@ In your C++ code, declare kernels like: if not gemm_declarations and not conv_declarations: print_error(" No kernel declarations found!") - print(" Add DECL_KERNEL_SET for GEMM or DECL_CONV_KERNEL_SET for Conv") + print( + " Add DECL_KERNEL_SET for GEMM or DECL_GROUPED_CONV_KERNEL_SET for Grouped Conv" + ) return 1 # Handle GEMM declarations @@ -1913,7 +1914,7 @@ In your C++ code, declare kernels like: is_valid, error_msg = validate_kernel_config(decl, arch) if not is_valid: - print(f"\n ⚠ Invalid configuration: {decl_name}") + print(f"\n WARNING Invalid configuration: {decl_name}") # Parse the error and show specific auto-corrections corrections = [] @@ -1926,7 +1927,7 @@ In your C++ code, declare kernels like: decl["wave_m"] = -1 decl["wave_n"] = -1 corrections.append( - f"wave: {original_values['wave']} → [wildcard expansion]" + f"wave: {original_values['wave']} -> [wildcard expansion]" ) if "warp tile" in error_msg.lower(): @@ -1936,7 +1937,7 @@ In your C++ code, declare kernels like: decl["warp_m"] = -1 decl["warp_n"] = -1 corrections.append( - f"warp_tile: {original_values['warp']} → [wildcard expansion]" + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" ) if "trait combination" in error_msg.lower(): @@ -1945,16 +1946,16 @@ In your C++ code, declare kernels like: decl["pipeline"] = "*" decl["scheduler"] = "*" corrections.append( - f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" ) corrections.append( - f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" ) # Print the auto-corrections print(" AUTO-CORRECTION:") for corr in corrections: - print(f" • {corr}") + print(f" - {corr}") auto_corrections.append((decl_name, corrections)) invalid_count += 1 @@ -1962,15 +1963,15 @@ In your C++ code, declare kernels like: if invalid_count > 0: print( - f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" ) if wildcard_count > 0: print( - f" ✓ {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + f" OK {len(gemm_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" ) else: - print(f" ✓ All {len(gemm_declarations)} configurations valid") + print(f" OK All {len(gemm_declarations)} configurations valid") # Expand GEMM declarations (for wildcards) print("\n Expanding wildcards to valid configurations...") @@ -1994,7 +1995,7 @@ In your C++ code, declare kernels like: wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" print( - f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" ) if len(expanded) > 3: print(f" ... and {len(expanded) - 3} more") @@ -2002,11 +2003,11 @@ In your C++ code, declare kernels like: exp = expanded[0] wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" - print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") if len(expanded_gemm) > len(gemm_declarations): print( - f"\n Total: {len(gemm_declarations)} declarations → {len(expanded_gemm)} configurations" + f"\n Total: {len(gemm_declarations)} declarations -> {len(expanded_gemm)} configurations" ) gemm_declarations = expanded_gemm @@ -2054,7 +2055,7 @@ In your C++ code, declare kernels like: is_valid, error_msg = validate_conv_kernel_config(decl, arch) if not is_valid: - print(f"\n ⚠ Invalid conv configuration: {decl_name}") + print(f"\n WARNING Invalid conv configuration: {decl_name}") # Parse the error and show specific auto-corrections corrections = [] @@ -2067,7 +2068,7 @@ In your C++ code, declare kernels like: decl["wave_m"] = -1 decl["wave_n"] = -1 corrections.append( - f"wave: {original_values['wave']} → [wildcard expansion]" + f"wave: {original_values['wave']} -> [wildcard expansion]" ) if "warp tile" in error_msg.lower(): @@ -2077,7 +2078,7 @@ In your C++ code, declare kernels like: decl["warp_m"] = -1 decl["warp_n"] = -1 corrections.append( - f"warp_tile: {original_values['warp']} → [wildcard expansion]" + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" ) if "trait combination" in error_msg.lower(): @@ -2086,16 +2087,16 @@ In your C++ code, declare kernels like: decl["pipeline"] = "*" decl["scheduler"] = "*" corrections.append( - f"pipeline: {original_values['pipeline']} → [wildcard expansion]" + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" ) corrections.append( - f"scheduler: {original_values['scheduler']} → [wildcard expansion]" + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" ) # Print the auto-corrections print(" AUTO-CORRECTION:") for corr in corrections: - print(f" • {corr}") + print(f" - {corr}") auto_corrections.append((decl_name, corrections)) invalid_count += 1 @@ -2103,15 +2104,15 @@ In your C++ code, declare kernels like: if invalid_count > 0: print( - f"\n ⚠ {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" ) if wildcard_count > 0: print( - f" ✓ {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + f" OK {len(conv_declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" ) else: - print(f" ✓ All {len(conv_declarations)} configurations valid") + print(f" OK All {len(conv_declarations)} configurations valid") # Expand Conv declarations (for wildcards) print("\n Expanding wildcards to valid configurations...") @@ -2134,7 +2135,7 @@ In your C++ code, declare kernels like: wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" print( - f" → wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}, scheduler={exp['scheduler']}" ) if len(expanded) > 3: print(f" ... and {len(expanded) - 3} more") @@ -2142,11 +2143,11 @@ In your C++ code, declare kernels like: exp = expanded[0] wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" - print(f" {decl_name}: → wave={wave_str}, warp={warp_str}") + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") if len(expanded_conv) > len(conv_declarations): print( - f"\n Total: {len(conv_declarations)} declarations → {len(expanded_conv)} configurations" + f"\n Total: {len(conv_declarations)} declarations -> {len(expanded_conv)} configurations" ) conv_declarations = expanded_conv diff --git a/dispatcher/scripts/compile_grouped_conv_examples.py b/dispatcher/scripts/compile_grouped_conv_examples.py new file mode 100644 index 0000000000..32fe70a2de --- /dev/null +++ b/dispatcher/scripts/compile_grouped_conv_examples.py @@ -0,0 +1,882 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Self-contained build script for C++ grouped convolution examples. + +Parses DECL_GROUPED_CONV_KERNEL_SET declarations from source files, +generates the needed kernels, and compiles the example. + +Includes validation and auto-correction via wildcard expansion. + +Usage: + python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/02_grouped_conv_forward.cpp + python3 compile_grouped_conv_examples.py examples/grouped_conv/cpp/03_grouped_conv_validation.cpp --no-compile +""" + +import argparse +import os +import re +import subprocess +import sys +from concurrent.futures import ProcessPoolExecutor, as_completed +from pathlib import Path +from typing import Optional + +# Setup paths +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +CK_ROOT = DISPATCHER_DIR.parent + +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ( # noqa: E402 + print_phase, + print_success, + print_error, + print_info, + find_hipcc, + get_arch_filter_data, + get_build_dir, + get_ck_root, + get_dispatcher_root, + get_generated_kernels_dir, +) + + +def extract_grouped_conv_declarations(source_file: Path) -> list: + """Extract DECL_GROUPED_CONV_KERNEL_SET declarations from C++ source.""" + content = source_file.read_text() + declarations = [] + + # Pattern: DECL_GROUPED_CONV_KERNEL_SET(name, .add(...).add(...)) + # Find all DECL_GROUPED_CONV_KERNEL_SET blocks by matching parentheses + pattern_start = r"DECL_GROUPED_CONV_KERNEL_SET\s*\(\s*(\w+)\s*," + for match in re.finditer(pattern_start, content): + set_name = match.group(1) + start_pos = match.end() + + # Find matching closing paren by counting parens + paren_count = 1 # We're already inside the first paren + end_pos = start_pos + for i, c in enumerate(content[start_pos:]): + if c == "(": + paren_count += 1 + elif c == ")": + paren_count -= 1 + if paren_count == 0: + end_pos = start_pos + i + break + + set_body = content[start_pos:end_pos] + + # Pattern 1: Simple add("dtype", "layout", "conv_type", tile_k, tile_c) + simple_add = ( + r'\.add\s*\(\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*"(\w+)"\s*,\s*(\d+)\s*,\s*(\d+)' + ) + for add_match in re.finditer(simple_add, set_body): + conv_type = add_match.group(3) + default_pipeline = ( + "compv3" if conv_type in ("bwd_data", "bwd_weight") else "compv4" + ) + declarations.append( + { + "set": set_name, + "dtype": add_match.group(1), + "layout": add_match.group(2), + "conv_type": conv_type, + "tile_k": int(add_match.group(4)), + "tile_c": int(add_match.group(5)), + "num_dims": 2, + "pipeline": default_pipeline, + "scheduler": "intrawave", + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + "arch": "gfx942", + } + ) + + # Pattern 2: Full ConvSig()/ConvAlgo() specification + # Find all .add( positions that start with ConvSig() + full_add = r"\.add\s*\(\s*ConvSig\(\)" + add_positions = [m.start() for m in re.finditer(full_add, set_body)] + + for pos in add_positions: + # Find matching closing paren by counting parens + paren_count = 0 + in_add = False + end = pos + for i, c in enumerate(set_body[pos:]): + if c == "(": + paren_count += 1 + in_add = True + elif c == ")": + paren_count -= 1 + if in_add and paren_count == 0: + end = pos + i + 1 + break + + add_str = set_body[pos:end] + + # Extract signature part (between ConvSig() and ConvAlgo()) + sig_match = re.search(r"ConvSig\(\)(.*?)ConvAlgo\(\)", add_str, re.DOTALL) + if not sig_match: + continue + sig_str = sig_match.group(1) + + # Extract algorithm part (between ConvAlgo() and arch string) + algo_match = re.search( + r'ConvAlgo\(\)(.*?),\s*"(\w+)"\s*\)', add_str, re.DOTALL + ) + if not algo_match: + continue + algo_str = algo_match.group(1) + arch = algo_match.group(2) + + # Parse signature + dtype = "fp16" + dtype_match = re.search(r'\.dtype\s*\(\s*"(\w+)"', sig_str) + if dtype_match: + dtype = dtype_match.group(1) + + layout = "nhwgc" + layout_match = re.search(r'\.layout\s*\(\s*"(\w+)"', sig_str) + if layout_match: + layout = layout_match.group(1) + + conv_type = "forward" + conv_type_match = re.search(r'\.conv_type\s*\(\s*"(\w+)"', sig_str) + if conv_type_match: + conv_type = conv_type_match.group(1) + + num_dims = 2 + dims_match = re.search(r"\.dims\s*\(\s*(\d+)", sig_str) + if dims_match: + num_dims = int(dims_match.group(1)) + + # Parse algorithm + tile_k, tile_c = 128, 128 + tile_match = re.search( + r"\.tile\s*\(\s*\d+\s*,\s*(\d+)\s*,\s*(\d+)", algo_str + ) + if tile_match: + tile_k = int(tile_match.group(1)) + tile_c = int(tile_match.group(2)) + + wave_m, wave_n, wave_k = 2, 2, 1 + wave_match = re.search( + r"\.wave\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str + ) + if wave_match: + wave_m = int(wave_match.group(1)) + wave_n = int(wave_match.group(2)) + wave_k = int(wave_match.group(3) or 1) + + warp_m, warp_n, warp_k = 32, 32, 16 + warp_match = re.search( + r"\.warp\s*\(\s*(\d+)\s*,\s*(\d+)(?:\s*,\s*(\d+))?", algo_str + ) + if warp_match: + warp_m = int(warp_match.group(1)) + warp_n = int(warp_match.group(2)) + warp_k = int(warp_match.group(3) or 16) + + pipeline = "compv4" + pipeline_match = re.search(r'\.pipeline\s*\(\s*"(\w+)"', algo_str) + if pipeline_match: + pipeline = pipeline_match.group(1) + + scheduler = "intrawave" + scheduler_match = re.search(r'\.scheduler\s*\(\s*"(\w+)"', algo_str) + if scheduler_match: + scheduler = scheduler_match.group(1) + + # Parse additional parameters + vector_a, vector_b, vector_c = 4, 8, 8 + vector_match = re.search( + r"\.vector_sizes\s*\(\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)", algo_str + ) + if vector_match: + vector_a = int(vector_match.group(1)) + vector_b = int(vector_match.group(2)) + vector_c = int(vector_match.group(3)) + + block_per_cu = 1 + block_per_cu_match = re.search(r"\.block_per_cu\s*\(\s*(\d+)", algo_str) + if block_per_cu_match: + block_per_cu = int(block_per_cu_match.group(1)) + + memory_op = "set" + memory_op_match = re.search(r'\.memory_op\s*\(\s*"(\w+)"', algo_str) + if memory_op_match: + memory_op = memory_op_match.group(1) + + epilogue = "cshuffle" + epilogue_match = re.search(r'\.epilogue\s*\(\s*"(\w+)"', algo_str) + if epilogue_match: + epilogue = epilogue_match.group(1) + + # Parse num_wave_groups (for V5 pipeline) + num_wave_groups = 1 + nwg_match = re.search(r"\.num_wave_groups\s*\(\s*(\d+)", algo_str) + if nwg_match: + num_wave_groups = int(nwg_match.group(1)) + + # Parse num_groups_to_merge (for merged group grouped convolution) + num_groups_to_merge = 1 + ngm_match = re.search(r"\.num_groups_to_merge\s*\(\s*(\d+)", algo_str) + if ngm_match: + num_groups_to_merge = int(ngm_match.group(1)) + + # Parse double_smem_buffer (for V4 pipeline) + double_smem_buffer = False + dsb_match = re.search( + r"\.double_smem_buffer\s*\(\s*(true|false)", algo_str, re.I + ) + if dsb_match: + double_smem_buffer = dsb_match.group(1).lower() == "true" + + # Parse padding flags + pad_m, pad_n, pad_k = True, True, True + padding_match = re.search( + r"\.padding\s*\(\s*(true|false)\s*,\s*(true|false)\s*,\s*(true|false)", + algo_str, + re.I, + ) + if padding_match: + pad_m = padding_match.group(1).lower() == "true" + pad_n = padding_match.group(2).lower() == "true" + pad_k = padding_match.group(3).lower() == "true" + + declarations.append( + { + "set": set_name, + "dtype": dtype, + "layout": layout, + "conv_type": conv_type, + "tile_k": tile_k, + "tile_c": tile_c, + "num_dims": num_dims, + "pipeline": pipeline, + "scheduler": scheduler, + "wave_m": wave_m, + "wave_n": wave_n, + "wave_k": wave_k, + "warp_m": warp_m, + "warp_n": warp_n, + "warp_k": warp_k, + "vector_a": vector_a, + "vector_b": vector_b, + "vector_c": vector_c, + "block_per_cu": block_per_cu, + "memory_op": memory_op, + "epilogue": epilogue, + "num_wave_groups": num_wave_groups, + "num_groups_to_merge": num_groups_to_merge, + "double_smem_buffer": double_smem_buffer, + "pad_m": pad_m, + "pad_n": pad_n, + "pad_k": pad_k, + "arch": arch, + } + ) + + return declarations + + +# ============================================================================= +# VALIDATION AND AUTO-CORRECTION +# ============================================================================= + + +def is_grouped_conv_wildcard_declaration(decl: dict) -> bool: + """Check if a declaration uses wildcards (-1 or '*').""" + wildcard_fields = ["wave_m", "wave_n", "warp_m", "warp_n", "pipeline", "scheduler"] + for field in wildcard_fields: + val = decl.get(field) + if val == -1 or val == "*": + return True + return False + + +def validate_grouped_conv_kernel_config(decl: dict, arch: str = "gfx942") -> tuple: + """Validate a grouped conv kernel configuration against known supported combinations. + + Returns: (is_valid, error_message) + """ + # Skip validation for wildcards - expansion will filter invalid combos + if is_grouped_conv_wildcard_declaration(decl): + return (True, None) + + arch_data = get_arch_filter_data() + + pipeline = decl.get("pipeline", "compv4") + scheduler = decl.get("scheduler", "intrawave") + dtype = decl.get("dtype", "fp16") + + wave_m = decl.get("wave_m", 2) + wave_n = decl.get("wave_n", 2) + wave_k = decl.get("wave_k", 1) + + warp_m = decl.get("warp_m", 32) + warp_n = decl.get("warp_n", 32) + warp_k = decl.get("warp_k", 16) + + errors = [] + + # Check trait combination (pipeline, epilogue, scheduler) + combo = (pipeline, "cshuffle", scheduler) + if combo in arch_data["trait_unsupported"]: + errors.append( + f"Unsupported trait combination: pipeline={pipeline}, scheduler={scheduler}\n" + f" Valid schedulers for {pipeline}: intrawave" + ) + + # Check wave configuration for this arch + warp_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + wave_cfg = [wave_m, wave_n, wave_k] + if wave_cfg not in warp_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_combos) + errors.append( + f"Unsupported wave configuration [{wave_m},{wave_n},{wave_k}] for {arch}\n" + f" Valid wave configs: {valid_str}" + ) + + # Check warp tile configuration for this arch and dtype + acc_dtype = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc_dtype}" + warp_tile_combos = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16], [16, 16, 32]]) + ) + warp_cfg = [warp_m, warp_n, warp_k] + if warp_cfg not in warp_tile_combos: + valid_str = ", ".join(f"[{c[0]},{c[1]},{c[2]}]" for c in warp_tile_combos[:5]) + errors.append( + f"Unsupported warp tile [{warp_m},{warp_n},{warp_k}] for {arch}/{dtype}\n" + f" Valid warp tiles: {valid_str}" + ) + + # Check arch is supported + if arch not in arch_data["supported_archs"]: + errors.append( + f"Unsupported architecture: {arch}\n" + f" Supported: {', '.join(arch_data['supported_archs'])}" + ) + + if errors: + return (False, "\n".join(errors)) + + return (True, None) + + +def expand_grouped_conv_declaration_with_arch_filter( + decl: dict, arch: str = "gfx942" +) -> list: + """Expand a grouped conv declaration with wildcards into valid configurations. + + Wildcards: + - wave_m/wave_n = -1: Try all valid wave configs for this arch + - warp_m/warp_n = -1: Try all valid warp tiles for this arch/dtype + - pipeline/scheduler = "*": Try all valid combinations + + Returns a list of fully-specified declarations. + """ + arch_data = get_arch_filter_data() + dtype = decl.get("dtype", "fp16") + + # Get valid combinations for this arch + valid_wave_combos = arch_data["warp_combos"].get(arch, [[2, 2, 1]]) + acc_dtype = "int32" if dtype == "int8" else "fp32" + dtype_key = f"{dtype}_{dtype}_{acc_dtype}" + valid_warp_tiles = ( + arch_data["warp_tile_combos"] + .get(arch, {}) + .get(dtype_key, [[32, 32, 16], [16, 16, 16]]) + ) + + # Valid pipelines and schedulers + valid_pipelines = ["compv3", "compv4"] + valid_schedulers = ["intrawave"] # interwave often unsupported + + # Determine which fields need expansion + expand_wave = decl.get("wave_m", 2) == -1 or decl.get("wave_n", 2) == -1 + expand_warp = decl.get("warp_m", 32) == -1 or decl.get("warp_n", 32) == -1 + expand_pipeline = decl.get("pipeline", "compv4") == "*" + expand_scheduler = decl.get("scheduler", "intrawave") == "*" + + # Build combinations + wave_options = ( + valid_wave_combos + if expand_wave + else [[decl.get("wave_m", 2), decl.get("wave_n", 2), decl.get("wave_k", 1)]] + ) + warp_options = ( + valid_warp_tiles + if expand_warp + else [[decl.get("warp_m", 32), decl.get("warp_n", 32), decl.get("warp_k", 16)]] + ) + pipeline_options = ( + valid_pipelines if expand_pipeline else [decl.get("pipeline", "compv4")] + ) + scheduler_options = ( + valid_schedulers if expand_scheduler else [decl.get("scheduler", "intrawave")] + ) + + expanded = [] + for wave in wave_options: + for warp in warp_options: + for pipeline in pipeline_options: + for scheduler in scheduler_options: + # Skip known invalid combinations + if (pipeline, "cshuffle", scheduler) in arch_data[ + "trait_unsupported" + ]: + continue + + new_decl = decl.copy() + new_decl["wave_m"] = wave[0] + new_decl["wave_n"] = wave[1] + new_decl["wave_k"] = wave[2] + new_decl["warp_m"] = warp[0] + new_decl["warp_n"] = warp[1] + new_decl["warp_k"] = warp[2] + new_decl["pipeline"] = pipeline + new_decl["scheduler"] = scheduler + + expanded.append(new_decl) + + # If no valid expansions, return original (will fail validation later) + if not expanded: + return [decl] + + # Return first valid config (or all if needed) + return expanded[:1] # Just use first valid config for grouped conv + + +def validate_and_expand_grouped_conv_declarations( + declarations: list, arch: str, verbose: bool = False +) -> list: + """Validate declarations and auto-correct invalid ones via wildcard expansion.""" + print(f"\n Validating against {arch} arch filter...") + + wildcard_count = 0 + invalid_count = 0 + auto_corrections = [] + + for decl in declarations: + decl_arch = decl.get("arch", arch) + decl_name = ( + f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}" + ) + + # Check for wildcards + if is_grouped_conv_wildcard_declaration(decl): + wildcard_count += 1 + continue + + is_valid, error_msg = validate_grouped_conv_kernel_config(decl, decl_arch) + if not is_valid: + print(f"\n WARNING Invalid grouped conv configuration: {decl_name}") + + # Parse the error and show specific auto-corrections + corrections = [] + original_values = {} + + if "wave configuration" in error_msg.lower(): + original_values["wave"] = ( + f"[{decl.get('wave_m', 2)}, {decl.get('wave_n', 2)}, {decl.get('wave_k', 1)}]" + ) + decl["wave_m"] = -1 + decl["wave_n"] = -1 + corrections.append( + f"wave: {original_values['wave']} -> [wildcard expansion]" + ) + + if "warp tile" in error_msg.lower(): + original_values["warp"] = ( + f"[{decl.get('warp_m', 32)}, {decl.get('warp_n', 32)}, {decl.get('warp_k', 16)}]" + ) + decl["warp_m"] = -1 + decl["warp_n"] = -1 + corrections.append( + f"warp_tile: {original_values['warp']} -> [wildcard expansion]" + ) + + if "trait combination" in error_msg.lower(): + original_values["pipeline"] = decl.get("pipeline", "compv4") + original_values["scheduler"] = decl.get("scheduler", "intrawave") + decl["pipeline"] = "*" + decl["scheduler"] = "*" + corrections.append( + f"pipeline: {original_values['pipeline']} -> [wildcard expansion]" + ) + corrections.append( + f"scheduler: {original_values['scheduler']} -> [wildcard expansion]" + ) + + # Print the auto-corrections + print(" AUTO-CORRECTION:") + for corr in corrections: + print(f" - {corr}") + auto_corrections.append((decl_name, corrections)) + + invalid_count += 1 + wildcard_count += 1 + + if invalid_count > 0: + print( + f"\n WARNING {invalid_count} invalid config(s) auto-corrected via wildcard expansion" + ) + + if wildcard_count > 0: + print( + f" OK {len(declarations) - wildcard_count} explicit + {wildcard_count} wildcard (will expand)" + ) + else: + print(f" OK All {len(declarations)} configurations valid") + + # Expand wildcards + print("\n Expanding wildcards to valid configurations...") + expanded_declarations = [] + for decl in declarations: + decl_arch = decl.get("arch", arch) + decl_name = ( + f"{decl['dtype']}_{decl['conv_type']}_{decl['tile_k']}x{decl['tile_c']}" + ) + + expanded = expand_grouped_conv_declaration_with_arch_filter(decl, decl_arch) + expanded_declarations.extend(expanded) + + if len(expanded) > 1: + print( + f" {decl_name}: expanded to {len(expanded)} valid configurations" + ) + for exp in expanded[:3]: + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print( + f" -> wave={wave_str}, warp={warp_str}, pipeline={exp['pipeline']}" + ) + if len(expanded) > 3: + print(f" ... and {len(expanded) - 3} more") + elif is_grouped_conv_wildcard_declaration(decl) and len(expanded) == 1: + exp = expanded[0] + wave_str = f"[{exp['wave_m']}, {exp['wave_n']}, {exp['wave_k']}]" + warp_str = f"[{exp['warp_m']}, {exp['warp_n']}, {exp['warp_k']}]" + print(f" {decl_name}: -> wave={wave_str}, warp={warp_str}") + + if len(expanded_declarations) != len(declarations): + print( + f"\n Total: {len(declarations)} declarations -> {len(expanded_declarations)} configurations" + ) + + return expanded_declarations + + +def _generate_single_grouped_conv_kernel(args: tuple) -> tuple: + """Generate one grouped conv kernel (picklable for ProcessPoolExecutor). + + Args: (decl, output_dir_str, gpu_target) + Returns: (idx, filepath_str or None, error_str or None) + """ + decl, output_dir_str, gpu_target = args + output_dir = Path(output_dir_str) + idx = decl.get("_idx", 0) + + try: + from codegen_common import TileConfig + from unified_grouped_conv_codegen import ( + GroupedConvKernelConfig, + GroupedConvTraitConfig, + GroupedConvVariant, + UnifiedGroupedConvCodegen, + ) + + # Map conv_type to variant + variant = GroupedConvVariant.FORWARD + if decl["conv_type"] == "bwd_data": + variant = GroupedConvVariant.BACKWARD_DATA + elif decl["conv_type"] == "bwd_weight": + variant = GroupedConvVariant.BACKWARD_WEIGHT + + pipeline = decl.get("pipeline", "compv4") + adj_tile_k = 64 * 2 if pipeline == "compv4" else 64 + + # Create tile config (tile_m=tile_k, tile_n=tile_c for conv GEMM view) + tile = TileConfig( + tile_m=decl["tile_k"], + tile_n=decl["tile_c"], + tile_k=adj_tile_k, + warp_m=decl["wave_m"], + warp_n=decl["wave_n"], + warp_k=decl.get("wave_k", 1), + warp_tile_m=decl["warp_m"], + warp_tile_n=decl["warp_n"], + warp_tile_k=decl["warp_k"], + ) + + trait = GroupedConvTraitConfig( + pipeline=pipeline, + scheduler=decl["scheduler"], + epilogue=decl.get("epilogue", "cshuffle"), + double_smem_buffer=decl.get("double_smem_buffer", False), + pad_m=decl.get("pad_m", True), + pad_n=decl.get("pad_n", True), + pad_k=decl.get("pad_k", True), + num_groups_to_merge=decl.get("num_groups_to_merge", 1), + ) + + config = GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=variant, + ndim_spatial=decl["num_dims"], + arch=decl.get("arch", gpu_target), + vector_size_a=decl.get("vector_a", 4), + vector_size_b=decl.get("vector_b", 8), + vector_size_c=decl.get("vector_c", 8), + block_per_cu=decl.get("block_per_cu", 1), + num_wave_groups=decl.get("num_wave_groups", 1), + num_groups_to_merge=decl.get("num_groups_to_merge", 1), + double_smem_buffer=decl.get("double_smem_buffer", False), + ) + + codegen = UnifiedGroupedConvCodegen(output_dir, gpu_target=gpu_target) + kernel_path, _ = codegen.generate_kernel(config, decl["dtype"], variant) + return (idx, str(kernel_path), None) + + except Exception as e: + return (idx, None, str(e)) + + +def generate_grouped_conv_kernels( + declarations: list, + output_dir: Path, + gpu_target: str = "gfx942", + max_workers: Optional[int] = None, +) -> list: + """Generate grouped convolution kernels using unified_grouped_conv_codegen. + + Uses ProcessPoolExecutor for parallel kernel generation. + """ + output_dir.mkdir(parents=True, exist_ok=True) + + # Prepare work items (add _idx for ordering) + work_items = [] + for idx, decl in enumerate(declarations): + decl_copy = decl.copy() + decl_copy["_idx"] = idx + work_items.append((decl_copy, str(output_dir), gpu_target)) + + max_workers = max_workers or min(len(work_items), os.cpu_count() or 4) + generated = [] + failed = [] + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = { + executor.submit(_generate_single_grouped_conv_kernel, w): w[0]["_idx"] + for w in work_items + } + for future in as_completed(futures): + idx, path, err = future.result() + if path: + generated.append(Path(path)) + print_info(f" Generated: {Path(path).name}") + else: + failed.append((idx, err)) + print_error(f" Failed kernel {idx + 1}: {err}") + + if failed: + for idx, err in failed[:3]: + print_error(f" Kernel {idx + 1}: {err[:200]}") + if len(failed) > 3: + print_error(f" ... and {len(failed) - 3} more failures") + + return generated + + +def compile_grouped_conv_example( + source_file: Path, + output_bin: Path, + kernel_headers: list, + hipcc: str, + gpu_target: str, +) -> bool: + """Compile the C++ example with generated kernels.""" + kernel_dir = get_generated_kernels_dir() + ck_root = get_ck_root() + dispatcher_dir = get_dispatcher_root() + + includes = [ + f"-I{ck_root / 'include'}", + f"-I{dispatcher_dir / 'include'}", + f"-I{kernel_dir}", + ] + + # Build include flags for generated kernels + kernel_includes = [] + for header in kernel_headers: + kernel_includes.extend(["-include", str(header)]) + + # Add define to indicate kernels are available + defines = ["-DGROUPED_CONV_KERNEL_AVAILABLE=1"] + + cmd = [ + hipcc, + "-std=c++20", + "-O2", + f"--offload-arch={gpu_target}", + *includes, + *defines, + *kernel_includes, + "-o", + str(output_bin), + str(source_file), + ] + + print_info(f" Compiling: {source_file.name}") + result = subprocess.run(cmd, capture_output=True, text=True) + + if result.returncode != 0: + if result.stderr: + lines = result.stderr.split("\n") + errors = [line for line in lines if "error:" in line.lower()][:5] + for err_line in errors: + print_error(f" {err_line}") + return False + + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Build C++ grouped convolution example with self-contained kernel generation" + ) + parser.add_argument("source", help="Source file (.cpp)") + parser.add_argument("--output", "-o", help="Output binary name") + parser.add_argument("--gpu-target", default="gfx942", help="GPU target") + parser.add_argument( + "--no-compile", action="store_true", help="Only generate kernels, don't compile" + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument( + "--jobs", + "-j", + type=int, + default=None, + help="Parallel jobs for kernel generation (default: cpu_count)", + ) + args = parser.parse_args() + + # Resolve source file + source_file = Path(args.source) + if not source_file.is_absolute(): + candidates = [ + get_dispatcher_root() / args.source, + Path.cwd() / args.source, + ] + for c in candidates: + if c.exists(): + source_file = c + break + + if not source_file.exists(): + print_error(f"Source file not found: {source_file}") + return 1 + + build_dir = get_build_dir() + kernel_dir = get_generated_kernels_dir() + output_name = args.output or source_file.stem + output_bin = build_dir / output_name + + print_success("=== Grouped Conv Example Builder (Self-Contained) ===") + + # Phase 1: Extract declarations + print_phase(1, "Scanning for DECL_GROUPED_CONV_KERNEL_SET...") + declarations = extract_grouped_conv_declarations(source_file) + + if not declarations: + print_error(" No DECL_GROUPED_CONV_KERNEL_SET declarations found!") + return 1 + + print(f" Found {len(declarations)} kernel declaration(s):") + for decl in declarations: + name = f"{decl['dtype']}_{decl['conv_type']}_{decl['num_dims']}d_{decl['tile_k']}x{decl['tile_c']}" + print(f" [{decl['set']}] {name}") + + # Phase 2: Validate and expand + print_phase(2, "Validating and expanding declarations...") + declarations = validate_and_expand_grouped_conv_declarations( + declarations, args.gpu_target, args.verbose + ) + print() + + # Phase 3: Generate kernels + print_phase(3, "Generating kernels...") + generated = generate_grouped_conv_kernels( + declarations, kernel_dir, args.gpu_target, max_workers=args.jobs + ) + + if not generated: + print_error(" No kernels generated!") + return 1 + + print(f" Generated {len(generated)} kernel file(s)") + print() + + # Phase 4: Compile (optional) + if args.no_compile: + print_info("Skipping compilation (--no-compile)") + print() + print_success("=== Kernel Generation Complete ===") + print(f"Kernels in: {kernel_dir}") + return 0 + + print_phase(4, "Compiling example...") + hipcc_path = find_hipcc() + + if not hipcc_path: + print_error(" hipcc not found. Install ROCm or set HIPCC env var.") + print(" To compile manually:") + ck_root = get_dispatcher_root().parent + print( + f" hipcc -std=c++20 -O2 -I{ck_root / 'include'} -I{get_dispatcher_root() / 'include'} \\" + ) + print(f" -I{kernel_dir} \\") + for h in generated[:1]: + print(f" -include {h} \\") + print(" -DGROUPED_CONV_KERNEL_AVAILABLE=1 \\") + print(f" --offload-arch={args.gpu_target} \\") + print(f" {source_file} -o {output_bin}") + return 1 + + build_dir.mkdir(parents=True, exist_ok=True) + + if not compile_grouped_conv_example( + source_file, output_bin, generated, hipcc_path, args.gpu_target + ): + print_error(" Compilation failed!") + return 1 + + print_success(f" Output: {output_bin}") + print() + + print_success("=== Build Complete ===") + print() + print("Run with:") + print(f" {output_bin}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/dispatcher/scripts/example_kernel_builder.py b/dispatcher/scripts/example_kernel_builder.py index d3bb619174..20952cd91f 100755 --- a/dispatcher/scripts/example_kernel_builder.py +++ b/dispatcher/scripts/example_kernel_builder.py @@ -55,10 +55,10 @@ def extract_balanced_parens(text: str, start_pos: int) -> str: def parse_conv_declarations(content: str) -> List[Dict]: - """Parse DECL_CONV_KERNEL_SET declarations with all parameters.""" + """Parse DECL_GROUPED_CONV_KERNEL_SET declarations with all parameters.""" kernels = [] - for match in re.finditer(r"DECL_CONV_KERNEL_SET\s*\(", content): + for match in re.finditer(r"DECL_GROUPED_CONV_KERNEL_SET\s*\(", content): body = extract_balanced_parens(content, match.end() - 1) if not body: continue @@ -619,7 +619,7 @@ def strip_cpp_strings_and_comments(content: str) -> str: n = len(content) # Patterns that indicate a string is problematic and should be stripped - problematic_patterns = ["DECL_KERNEL_SET", "DECL_CONV_KERNEL_SET", ".add("] + problematic_patterns = ["DECL_KERNEL_SET", "DECL_GROUPED_CONV_KERNEL_SET", ".add("] while i < n: # Check for raw string literal: R"delimiter(...)delimiter" @@ -697,7 +697,7 @@ def detect_and_parse(source_path: Path) -> Tuple[str, List[Dict]]: content = source_path.read_text() content = strip_cpp_strings_and_comments(content) - if "DECL_CONV_KERNEL_SET" in content: + if "DECL_GROUPED_CONV_KERNEL_SET" in content: return "conv", parse_conv_declarations(content) elif "DECL_KERNEL_SET" in content: return "gemm", parse_gemm_declarations(content) @@ -966,30 +966,128 @@ def generate_per_set_functions(source_stem: str) -> str: def generate_conv_registration( kernel_headers: List[Path], example_name: str, kernels: List[Dict] ) -> str: - """Generate Conv kernel registration code for the dispatcher registry.""" + """Generate Conv kernel registration code for the dispatcher registry. + + Creates real GroupedConvKernelInstance entries backed by the generated + launcher's launch() method via the conv backend RunFn factories. + """ if not kernel_headers: return " // No kernels to register" lines = [] - lines.append( - " (void)registry; (void)arch; // Conv uses direct launcher pattern for now" - ) - # For conv, we provide direct access to kernel launchers for i, h in enumerate(kernel_headers): - kernel_name = h.stem - lines.append(f" // Kernel {i + 1}: {kernel_name}") + kname = h.stem + ns = f"ns_{kname}" + launcher = f"{ns}::{kname}_Launcher" + + # Determine direction and ndim from the kernel header name + if "_fwd_" in kname: + direction = "Forward" + run_fn_factory = "make_conv_fwd_run_fn" + elif "_bwd_data_" in kname or "_bwdd_" in kname: + direction = "BackwardData" + run_fn_factory = "make_conv_bwd_data_run_fn" + elif "_bwd_weight_" in kname or "_bwdw_" in kname: + direction = "BackwardWeight" + run_fn_factory = "make_conv_bwd_weight_run_fn" + else: + direction = "Forward" + run_fn_factory = "make_conv_fwd_run_fn" + + ndim = 3 if "_3d_" in kname else 2 + + # Parse dtype from name (e.g. grouped_conv_fwd_fp16_...) + dtype = "fp16" + for dt in ["fp16", "bf16", "fp32"]: + if f"_{dt}_" in kname: + dtype = dt + break + + # Parse tile, wave, warp from name. + # Format: ..._TILExTILExTILE_WAVExWAVExWAVE_WARPxWARPxWARP_... + import re as _re + + tile_m, tile_n, tile_k = 1, 128, 128 + wave_m, wave_n, wave_k = 2, 2, 1 + warp_m, warp_n, warp_k = 32, 32, 16 + + triplets = _re.findall(r"_(\d+)x(\d+)x(\d+)", kname) + if len(triplets) >= 1: + tile_m, tile_n, tile_k = ( + int(triplets[0][0]), + int(triplets[0][1]), + int(triplets[0][2]), + ) + if len(triplets) >= 2: + wave_m, wave_n, wave_k = ( + int(triplets[1][0]), + int(triplets[1][1]), + int(triplets[1][2]), + ) + if len(triplets) >= 3: + warp_m, warp_n, warp_k = ( + int(triplets[2][0]), + int(triplets[2][1]), + int(triplets[2][2]), + ) + + pipeline = "compv4" if "compv4" in kname else "compv3" + scheduler = "interwave" if "interwave" in kname else "intrawave" + epilogue = "cshuffle" if "cshuffle" in kname else "default" + + # ConvConfigBase defaults + vec_a, vec_b, vec_c = 4, 8, 8 + block_per_cu = 1 + num_wave_groups = 1 + num_groups_to_merge = 1 + + lines.append(f" // Kernel {i + 1}: {kname}") + lines.append(" {") + lines.append(f" ck_tile::dispatcher::GroupedConvKernelKey key_{i};") + lines.append(f' key_{i}.dtype_in = "{dtype}";') + lines.append(f' key_{i}.dtype_wei = "{dtype}";') + lines.append(f' key_{i}.dtype_out = "{dtype}";') + lines.append(f' key_{i}.layout = "nhwgc";') + lines.append(f" key_{i}.ndim_spatial = {ndim};") + lines.append( + f" key_{i}.op = ck_tile::dispatcher::GroupedConvOp::{direction};" + ) + lines.append(f" key_{i}.tile_m = {tile_m};") + lines.append(f" key_{i}.tile_n = {tile_n};") + lines.append(f" key_{i}.tile_k = {tile_k};") + lines.append(f" key_{i}.wave_m = {wave_m};") + lines.append(f" key_{i}.wave_n = {wave_n};") + lines.append(f" key_{i}.wave_k = {wave_k};") + lines.append(f" key_{i}.warp_m = {warp_m};") + lines.append(f" key_{i}.warp_n = {warp_n};") + lines.append(f" key_{i}.warp_k = {warp_k};") + lines.append(f' key_{i}.pipeline = "{pipeline}";') + lines.append(f' key_{i}.scheduler = "{scheduler}";') + lines.append(f' key_{i}.epilogue = "{epilogue}";') + lines.append(f" key_{i}.vector_size_a = {vec_a};") + lines.append(f" key_{i}.vector_size_b = {vec_b};") + lines.append(f" key_{i}.vector_size_c = {vec_c};") + lines.append(f" key_{i}.block_per_cu = {block_per_cu};") + lines.append(f" key_{i}.num_wave_groups = {num_wave_groups};") + lines.append(f" key_{i}.num_groups_to_merge = {num_groups_to_merge};") + lines.append(f" key_{i}.arch = arch;") + lines.append( + f" auto run_fn_{i} = ck_tile::dispatcher::backends::{run_fn_factory}<{launcher}, {ndim}>();" + ) + lines.append( + f' auto inst_{i} = std::make_shared(key_{i}, "{kname}", std::move(run_fn_{i}));' + ) + lines.append(f" registry.register_kernel(key_{i}, inst_{i});") + lines.append(" }") return "\n".join(lines) -def generate_conv_kernels( - kernels: List[Dict], output_dir: Path, codegen_dir: Path -) -> bool: - """Generate Conv kernels for ALL declarations using unified codegen.""" - if not kernels: - return False - +def _build_conv_codegen_cmd( + idx: int, k: Dict, codegen_dir: Path, output_dir: Path +) -> Tuple[int, List[str], str]: + """Build the command for a single conv kernel codegen invocation.""" variant_map = { "forward": "forward", "bwd_data": "bwd_data", @@ -997,93 +1095,130 @@ def generate_conv_kernels( "bwd_weight": "bwd_weight", "backward_weight": "bwd_weight", } + variant = variant_map.get(k.get("conv_type", "forward"), "forward") + + cmd = [ + sys.executable, + str(codegen_dir / "unified_grouped_conv_codegen.py"), + "--datatype", + k.get("dtype", "fp16"), + "--variant", + variant, + "--ndim", + str(k.get("ndim", 2)), + "--output", + str(output_dir), + ] + + if k.get("tile_m"): + cmd.extend(["--tile-m", str(k["tile_m"])]) + if k.get("tile_n"): + cmd.extend(["--tile-n", str(k["tile_n"])]) + if k.get("warp_m"): + cmd.extend(["--warp-m", str(k["warp_m"])]) + if k.get("warp_n"): + cmd.extend(["--warp-n", str(k["warp_n"])]) + if k.get("warp_k"): + cmd.extend(["--warp-k", str(k["warp_k"])]) + if k.get("warp_tile_m"): + cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) + if k.get("warp_tile_n"): + cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) + if k.get("warp_tile_k"): + cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) + if k.get("pipeline"): + cmd.extend(["--pipeline", k["pipeline"]]) + if k.get("scheduler"): + cmd.extend(["--scheduler", k["scheduler"]]) + if k.get("epilogue"): + cmd.extend(["--epilogue", k["epilogue"]]) + if k.get("vector_a"): + cmd.extend(["--vector-a", str(k["vector_a"])]) + if k.get("vector_b"): + cmd.extend(["--vector-b", str(k["vector_b"])]) + if k.get("vector_c"): + cmd.extend(["--vector-c", str(k["vector_c"])]) + if k.get("block_per_cu"): + cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) + if k.get("num_wave_groups"): + cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) + if k.get("num_groups_to_merge"): + cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) + if k.get("double_smem_buffer") is not None: + cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) + if k.get("tile_k"): + cmd.extend(["--tile-k", str(k["tile_k"])]) + + return (idx, cmd, str(codegen_dir)) + + +def _run_conv_codegen(args: Tuple) -> Tuple[int, bool, str]: + """Run unified_grouped_conv_codegen.py for a single kernel config (picklable for ProcessPoolExecutor).""" + idx, cmd, cwd = args + result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd) + if result.returncode != 0: + return (idx, False, result.stderr[:300]) + return (idx, True, "") + + +def generate_conv_kernels( + kernels: List[Dict], output_dir: Path, codegen_dir: Path +) -> bool: + """Generate Conv kernels for ALL declarations using unified codegen. + + Launches all codegen subprocesses in parallel via ProcessPoolExecutor + for significantly faster generation when multiple conv kernels are declared. + """ + if not kernels: + return False + + work_items = [ + _build_conv_codegen_cmd(idx, k, codegen_dir, output_dir) + for idx, k in enumerate(kernels) + ] success_count = 0 + max_workers = min(len(work_items), os.cpu_count() or 4) - # Generate a kernel for EACH declaration - for idx, k in enumerate(kernels): - variant = variant_map.get(k.get("conv_type", "forward"), "forward") - - cmd = [ - sys.executable, - str(codegen_dir / "unified_conv_codegen.py"), - "--datatype", - k.get("dtype", "fp16"), - "--variant", - variant, - "--ndim", - str(k.get("ndim", 2)), - "--output", - str(output_dir), - ] - - # Add optional parameters if specified - if k.get("tile_m"): - cmd.extend(["--tile-m", str(k["tile_m"])]) - if k.get("tile_n"): - cmd.extend(["--tile-n", str(k["tile_n"])]) - if k.get("warp_m"): - cmd.extend(["--warp-m", str(k["warp_m"])]) - if k.get("warp_n"): - cmd.extend(["--warp-n", str(k["warp_n"])]) - if k.get("warp_k"): - cmd.extend(["--warp-k", str(k["warp_k"])]) - if k.get("warp_tile_m"): - cmd.extend(["--warp-tile-m", str(k["warp_tile_m"])]) - if k.get("warp_tile_n"): - cmd.extend(["--warp-tile-n", str(k["warp_tile_n"])]) - if k.get("warp_tile_k"): - cmd.extend(["--warp-tile-k", str(k["warp_tile_k"])]) - if k.get("pipeline"): - cmd.extend(["--pipeline", k["pipeline"]]) - if k.get("scheduler"): - cmd.extend(["--scheduler", k["scheduler"]]) - if k.get("epilogue"): - cmd.extend(["--epilogue", k["epilogue"]]) - if k.get("vector_a"): - cmd.extend(["--vector-a", str(k["vector_a"])]) - if k.get("vector_b"): - cmd.extend(["--vector-b", str(k["vector_b"])]) - if k.get("vector_c"): - cmd.extend(["--vector-c", str(k["vector_c"])]) - if k.get("block_per_cu"): - cmd.extend(["--block-per-cu", str(k["block_per_cu"])]) - if k.get("num_wave_groups"): - cmd.extend(["--num-wave-groups", str(k["num_wave_groups"])]) - if k.get("num_groups_to_merge"): - cmd.extend(["--num-groups-to-merge", str(k["num_groups_to_merge"])]) - if k.get("double_smem_buffer") is not None: - cmd.extend(["--double-smem-buffer", str(k["double_smem_buffer"]).lower()]) - if k.get("tile_k"): - cmd.extend(["--tile-k", str(k["tile_k"])]) - - result = subprocess.run( - cmd, capture_output=True, text=True, cwd=str(codegen_dir) - ) - if result.returncode != 0: - print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") - else: - success_count += 1 + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(_run_conv_codegen, w): w[0] for w in work_items} + for future in as_completed(futures): + idx, ok, err = future.result() + if ok: + success_count += 1 + else: + print(f" Codegen error for kernel {idx + 1}: {err}") return success_count > 0 +def _run_gemm_codegen(args: Tuple) -> Tuple[int, bool, str]: + """Run unified_gemm_codegen.py for a single kernel config (picklable for ProcessPoolExecutor).""" + idx, cmd, cwd = args + result = subprocess.run(cmd, capture_output=True, text=True, cwd=cwd) + if result.returncode != 0: + return (idx, False, result.stderr[:300]) + return (idx, True, "") + + def generate_gemm_kernels( kernels: List[Dict], output_dir: Path, codegen_dir: Path ) -> bool: - """Generate GEMM kernels for ALL declarations using unified codegen.""" + """Generate GEMM kernels for ALL declarations using unified codegen. + + Launches all codegen subprocesses in parallel via ProcessPoolExecutor + for significantly faster generation when multiple kernels are declared. + """ import json if not kernels: return False - success_count = 0 - - # Generate a kernel for EACH declaration + # Build all commands upfront + work_items = [] for idx, k in enumerate(kernels): variant = "multi_d" if k.get("elementwise_op") else "standard" - # Build tile config JSON for this specific kernel tile_config = { "tile_m": [k.get("tile_m", 128)], "tile_n": [k.get("tile_n", 128)], @@ -1125,13 +1260,20 @@ def generate_gemm_kernels( config_json, ] - result = subprocess.run( - cmd, capture_output=True, text=True, cwd=str(codegen_dir) - ) - if result.returncode != 0: - print(f" Codegen error for kernel {idx + 1}: {result.stderr[:300]}") - else: - success_count += 1 + work_items.append((idx, cmd, str(codegen_dir))) + + # Run all codegen subprocesses in parallel + success_count = 0 + max_workers = min(len(work_items), os.cpu_count() or 4) + + with ProcessPoolExecutor(max_workers=max_workers) as executor: + futures = {executor.submit(_run_gemm_codegen, w): w[0] for w in work_items} + for future in as_completed(futures): + idx, ok, err = future.result() + if ok: + success_count += 1 + else: + print(f" Codegen error for kernel {idx + 1}: {err}") return success_count > 0 @@ -1229,15 +1371,17 @@ def main(): if example_type == "gemm": kernel_headers = list(args.output_dir.glob("gemm_*.hpp")) else: - k = kernels[0] if kernels else {} - variant = k.get("conv_type", "forward") prefix_map = { - "forward": "conv_fwd", - "bwd_data": "conv_bwdd", - "bwd_weight": "conv_bwdw", + "forward": "grouped_conv_fwd", + "bwd_data": "grouped_conv_bwd_data", + "bwd_weight": "grouped_conv_bwd_weight", } - prefix = prefix_map.get(variant, "conv_fwd") - kernel_headers = list(args.output_dir.glob(f"{prefix}_*.hpp")) + # Collect headers from ALL variants present in declarations + variants_used = set(k.get("conv_type", "forward") for k in kernels) + kernel_headers = [] + for variant in variants_used: + prefix = prefix_map.get(variant, "grouped_conv_fwd") + kernel_headers.extend(args.output_dir.glob(f"{prefix}_*.hpp")) if not kernel_headers: print(f"[{target_name}] No kernel headers generated!") @@ -1347,29 +1491,39 @@ def main(): ) if has_bwd_data: - bwdd_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdd_") - if bwdd_kernel: - bwdd_ns = f"ns_{bwdd_kernel.stem}" - launcher_aliases.append( - f"using BwdDataKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + bwd_data_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwd_data_" + ) + if not bwd_data_kernel: + bwd_data_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwdd_" ) - if not has_fwd: # If no fwd, use bwd_data as first + if bwd_data_kernel: + bwd_data_ns = f"ns_{bwd_data_kernel.stem}" + launcher_aliases.append( + f"using BwdDataKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;" + ) + if not has_fwd: launcher_aliases.append( - f"using FirstKernelLauncher = {bwdd_ns}::{bwdd_kernel.stem}_Launcher;" + f"using FirstKernelLauncher = {bwd_data_ns}::{bwd_data_kernel.stem}_Launcher;" ) if has_bwd_weight: - bwdw_kernel = find_kernel_by_dtype_type(kernel_headers, "fp16", "_bwdw_") - if bwdw_kernel: - bwdw_ns = f"ns_{bwdw_kernel.stem}" - launcher_aliases.append( - f"using BwdWeightKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + bwd_weight_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwd_weight_" + ) + if not bwd_weight_kernel: + bwd_weight_kernel = find_kernel_by_dtype_type( + kernel_headers, "fp16", "_bwdw_" ) - if ( - not has_fwd and not has_bwd_data - ): # If no fwd or bwdd, use bwdw as first + if bwd_weight_kernel: + bwd_weight_ns = f"ns_{bwd_weight_kernel.stem}" + launcher_aliases.append( + f"using BwdWeightKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;" + ) + if not has_fwd and not has_bwd_data: launcher_aliases.append( - f"using FirstKernelLauncher = {bwdw_ns}::{bwdw_kernel.stem}_Launcher;" + f"using FirstKernelLauncher = {bwd_weight_ns}::{bwd_weight_kernel.stem}_Launcher;" ) launcher_section = "\n".join(launcher_aliases) @@ -1382,14 +1536,16 @@ def main(): #include "ck_tile/dispatcher/registry.hpp" #include "ck_tile/dispatcher/kernel_instance.hpp" #include "ck_tile/dispatcher/kernel_key.hpp" +#include "ck_tile/dispatcher/grouped_conv_registry.hpp" +#include "ck_tile/dispatcher/backends/generated_conv_backend.hpp" namespace generated {{ // Kernel launchers for direct use {launcher_section} -// Registration function -inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::string& arch) {{ +// Registration function (takes GroupedConvRegistry for conv kernels) +inline void {func_name}(ck_tile::dispatcher::GroupedConvRegistry& registry, const std::string& arch) {{ {register_body} }} @@ -1439,7 +1595,7 @@ inline void {func_name}(ck_tile::dispatcher::Registry& registry, const std::stri """ header_path.write_text(header_content) - print(f"[{target_name}] ✓ {len(obj_files)} kernels compiled") + print(f"[{target_name}] OK {len(obj_files)} kernels compiled") return 0 diff --git a/dispatcher/scripts/generate_conv_dispatch_header.py b/dispatcher/scripts/generate_conv_dispatch_header.py new file mode 100644 index 0000000000..55cc085ed9 --- /dev/null +++ b/dispatcher/scripts/generate_conv_dispatch_header.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +"""Generate the conv_python_dispatch.hpp header for the Python conv library. + +Reads the include_all headers to find available kernels and creates dispatch +aliases for 2D/3D x fwd/bwd_data/bwd_weight. +""" + +import argparse +import re +from pathlib import Path + + +def find_3d_launcher(include_all_path: Path, variant_prefix: str) -> str: + """Find first 3D launcher name from an include_all header.""" + text = include_all_path.read_text() + pattern = rf"(grouped_conv_{variant_prefix}_\w+_3d_\w+)\.hpp" + match = re.search(pattern, text) + if match: + return match.group(1) + "_Launcher" + return "" + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--kernel-dir", required=True) + parser.add_argument("--output", required=True) + args = parser.parse_args() + + kdir = Path(args.kernel_dir) + + fwd_3d = find_3d_launcher(kdir / "include_all_grouped_conv_fwd_kernels.hpp", "fwd") + bwd_data_3d = find_3d_launcher( + kdir / "include_all_grouped_conv_bwd_data_kernels.hpp", "bwd_data" + ) + bwd_weight_3d = find_3d_launcher( + kdir / "include_all_grouped_conv_bwd_weight_kernels.hpp", "bwd_weight" + ) + + lines = [ + "// Auto-generated dispatch header for Python conv library", + "#pragma once", + "", + "// Forward kernels", + '#include "include_all_grouped_conv_fwd_kernels.hpp"', + "#define CONV_FWD_2D_AVAILABLE 1", + ] + if fwd_3d: + lines += [ + "#define CONV_FWD_3D_AVAILABLE 1", + f"using ConvFwd3dLauncher = {fwd_3d};", + ] + lines += [ + "", + "// Backward data kernels", + '#include "include_all_grouped_conv_bwd_data_kernels.hpp"', + "#define CONV_BWD_DATA_2D_AVAILABLE 1", + ] + if bwd_data_3d: + lines += [ + "#define CONV_BWD_DATA_3D_AVAILABLE 1", + f"using ConvBwdData3dLauncher = {bwd_data_3d};", + ] + lines += [ + "", + "// Backward weight kernels", + '#include "include_all_grouped_conv_bwd_weight_kernels.hpp"', + "#define CONV_BWD_WEIGHT_2D_AVAILABLE 1", + ] + if bwd_weight_3d: + lines += [ + "#define CONV_BWD_WEIGHT_3D_AVAILABLE 1", + f"using ConvBwdWeight3dLauncher = {bwd_weight_3d};", + ] + + # Kernel name table for Python introspection + names = [] + if True: # fwd 2D always present + names.append('"fwd_2d"') + if fwd_3d: + names.append('"fwd_3d"') + if True: # bwd_data 2D + names.append('"bwd_data_2d"') + if bwd_data_3d: + names.append('"bwd_data_3d"') + if True: # bwd_weight 2D + names.append('"bwd_weight_2d"') + if bwd_weight_3d: + names.append('"bwd_weight_3d"') + + lines += [ + "", + "// Kernel inventory for Python", + f"static const char* CONV_KERNEL_NAMES[] = {{{', '.join(names)}}};", + f"static const int CONV_KERNEL_COUNT = {len(names)};", + "", + ] + + Path(args.output).write_text("\n".join(lines) + "\n") + print(f"Generated dispatch header: {args.output} ({len(names)} kernels)") + + +if __name__ == "__main__": + main() diff --git a/dispatcher/scripts/parallel_kernel_builder.py b/dispatcher/scripts/parallel_kernel_builder.py index 911ea61bd7..aef8f4ff0b 100755 --- a/dispatcher/scripts/parallel_kernel_builder.py +++ b/dispatcher/scripts/parallel_kernel_builder.py @@ -132,7 +132,7 @@ def main(): print(f"Linking failed: {result.stderr}") return 1 - print(f"✓ Built: {lib_path}") + print(f"OK Built: {lib_path}") return 0 diff --git a/dispatcher/scripts/stress_test_autocorrect.py b/dispatcher/scripts/stress_test_autocorrect.py index 13e92abffa..63b250071e 100644 --- a/dispatcher/scripts/stress_test_autocorrect.py +++ b/dispatcher/scripts/stress_test_autocorrect.py @@ -34,9 +34,9 @@ from compile_gemm_examples import ( # noqa: E402 validate_kernel_config, expand_declaration_with_arch_filter, ) -from compile_conv_examples import ( # noqa: E402 - validate_conv_kernel_config, - expand_conv_declaration_with_arch_filter, +from compile_grouped_conv_examples import ( # noqa: E402 + validate_grouped_conv_kernel_config as validate_conv_kernel_config, + expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter, ) @@ -316,7 +316,7 @@ def test_python_autocorrect(verbose=False): if was_modified: print(f" Modified: {len(corrections)} correction(s)") for c in corrections: - print(f" • {c}") + print(f" - {c}") except Exception as e: results["failed"] += 1 @@ -465,7 +465,7 @@ def run_stress_test(arch, num_samples, verbose): } expanded = expand_declaration_with_arch_filter(config, test_arch) - status = "✓" if expanded else "✗" + status = "OK" if expanded else "FAIL" expected = test_arch in test["expected_archs"] match = "OK" if (bool(expanded) == expected) else "MISMATCH" diff --git a/dispatcher/src/dispatcher.cpp b/dispatcher/src/dispatcher.cpp index fdb400921e..2cb589adf2 100644 --- a/dispatcher/src/dispatcher.cpp +++ b/dispatcher/src/dispatcher.cpp @@ -2,17 +2,18 @@ // SPDX-License-Identifier: MIT #include "ck_tile/dispatcher/dispatcher.hpp" -#include +#include "ck_tile/dispatcher/dispatcher_error.hpp" #include #include namespace ck_tile { namespace dispatcher { -Dispatcher::Dispatcher(Registry* registry) +Dispatcher::Dispatcher(Registry* registry, const std::string& gfx_arch) : registry_(registry ? registry : &Registry::instance()), heuristic_(nullptr), - strategy_(SelectionStrategy::FirstFit) + strategy_(SelectionStrategy::FirstFit), + gfx_arch_(gfx_arch) { } @@ -61,7 +62,7 @@ float Dispatcher::run_fused(const void* a_ptr, std::ostringstream oss; oss << "No suitable kernel found for problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K; - throw std::runtime_error(oss.str()); + throw NoKernelFound(oss.str()); } return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); @@ -78,7 +79,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id, auto kernel = registry_->lookup(kernel_id); if(!kernel) { - throw std::runtime_error("Kernel not found: " + kernel_id); + throw NoKernelFound("Kernel not found: " + kernel_id); } if(!kernel->supports(problem)) @@ -86,7 +87,7 @@ float Dispatcher::run_explicit(const std::string& kernel_id, std::ostringstream oss; oss << "Kernel " << kernel_id << " does not support problem: M=" << problem.M << " N=" << problem.N << " K=" << problem.K; - throw std::runtime_error(oss.str()); + throw UnsupportedProblem(oss.str()); } return kernel->run(a_ptr, b_ptr, c_ptr, d_ptrs, problem, stream); diff --git a/dispatcher/src/registry.cpp b/dispatcher/src/registry.cpp index 0d83afd613..f565885181 100644 --- a/dispatcher/src/registry.cpp +++ b/dispatcher/src/registry.cpp @@ -5,39 +5,32 @@ #include "ck_tile/dispatcher/json_export.hpp" #include "ck_tile/dispatcher/arch_filter.hpp" #include +#include +#include namespace ck_tile { namespace dispatcher { -Registry::Registry() - : name_("default"), - auto_export_enabled_(false), - auto_export_include_statistics_(true), - auto_export_on_every_registration_(true) -{ -} +Registry::Registry() = default; Registry::~Registry() { - // Perform auto-export on destruction if enabled (regardless of export_on_every_registration - // setting) if(auto_export_enabled_) { perform_auto_export(); } } -Registry::Registry(Registry&& other) noexcept - : mutex_() // mutex is not movable, create new one - , - kernels_(std::move(other.kernels_)), - name_(std::move(other.name_)), - auto_export_enabled_(other.auto_export_enabled_), - auto_export_filename_(std::move(other.auto_export_filename_)), - auto_export_include_statistics_(other.auto_export_include_statistics_), - auto_export_on_every_registration_(other.auto_export_on_every_registration_) +Registry::Registry(Registry&& other) noexcept : Base(std::move(other)) { - // Disable auto-export on the moved-from object to prevent double export + // Base move constructor already locked+released other.mutex_. + // Re-acquire to safely read the remaining fields. + std::lock_guard lock(other.mutex()); + auto_export_enabled_ = other.auto_export_enabled_; + auto_export_filename_ = std::move(other.auto_export_filename_); + auto_export_include_statistics_ = other.auto_export_include_statistics_; + auto_export_on_every_registration_ = other.auto_export_on_every_registration_; + other.auto_export_enabled_ = false; } @@ -45,11 +38,7 @@ Registry& Registry::operator=(Registry&& other) noexcept { if(this != &other) { - std::lock_guard lock(mutex_); - std::lock_guard other_lock(other.mutex_); - - kernels_ = std::move(other.kernels_); - name_ = std::move(other.name_); + Base::operator=(std::move(other)); auto_export_enabled_ = other.auto_export_enabled_; auto_export_filename_ = std::move(other.auto_export_filename_); auto_export_include_statistics_ = other.auto_export_include_statistics_; @@ -64,55 +53,27 @@ Registry& Registry::operator=(Registry&& other) noexcept bool Registry::register_kernel(KernelInstancePtr instance, Priority priority) { if(!instance) - { return false; - } - const std::string identifier = instance->get_key().encode_identifier(); - - bool registered = false; + if(Base::register_kernel(instance->get_name(), instance, priority)) { - std::lock_guard lock(mutex_); - - auto it = kernels_.find(identifier); - if(it != kernels_.end()) + if(auto_export_enabled_ && auto_export_on_every_registration_) { - // Kernel with this identifier already exists - // Only replace if new priority is higher - if(priority > it->second.priority) - { - it->second.instance = instance; - it->second.priority = priority; - registered = true; - } - } - else - { - // New kernel, insert it - kernels_[identifier] = RegistryEntry{instance, priority}; - registered = true; + perform_auto_export(); } + return true; } - - // Perform auto-export if enabled and configured to export on every registration - if(registered && auto_export_enabled_ && auto_export_on_every_registration_) - { - perform_auto_export(); - } - - return registered; + return false; } KernelInstancePtr Registry::lookup(const std::string& identifier) const { - std::lock_guard lock(mutex_); - - auto it = kernels_.find(identifier); - if(it != kernels_.end()) + std::lock_guard lock(mutex()); + auto it = entries().find(identifier); + if(it != entries().end()) { return it->second.instance; } - return nullptr; } @@ -121,75 +82,23 @@ KernelInstancePtr Registry::lookup(const KernelKey& key) const return lookup(key.encode_identifier()); } -std::vector Registry::get_all() const -{ - std::lock_guard lock(mutex_); - - std::vector result; - result.reserve(kernels_.size()); - - for(const auto& pair : kernels_) - { - result.push_back(pair.second.instance); - } - - return result; -} +std::vector Registry::get_all() const { return Base::get_all_instances(); } std::vector Registry::filter(std::function predicate) const { - std::lock_guard lock(mutex_); - + std::lock_guard lock(mutex()); std::vector result; - - for(const auto& pair : kernels_) + for(const auto& [name, entry] : entries()) { - if(predicate(*pair.second.instance)) + if(predicate(*(entry.instance))) { - result.push_back(pair.second.instance); + result.push_back(entry.instance); } } - return result; } -std::size_t Registry::size() const -{ - std::lock_guard lock(mutex_); - return kernels_.size(); -} - -bool Registry::empty() const -{ - std::lock_guard lock(mutex_); - return kernels_.empty(); -} - -void Registry::clear() -{ - std::lock_guard lock(mutex_); - kernels_.clear(); -} - -const std::string& Registry::get_name() const -{ - std::lock_guard lock(mutex_); - return name_; -} - -void Registry::set_name(const std::string& name) -{ - std::lock_guard lock(mutex_); - name_ = name; -} - -Registry& Registry::instance() -{ - static Registry global_registry; - return global_registry; -} - std::string Registry::export_json(bool include_statistics) const { return export_registry_json(*this, include_statistics); @@ -204,7 +113,7 @@ void Registry::enable_auto_export(const std::string& filename, bool include_statistics, bool export_on_every_registration) { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); auto_export_enabled_ = true; auto_export_filename_ = filename; auto_export_include_statistics_ = include_statistics; @@ -213,13 +122,13 @@ void Registry::enable_auto_export(const std::string& filename, void Registry::disable_auto_export() { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); auto_export_enabled_ = false; } bool Registry::is_auto_export_enabled() const { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); return auto_export_enabled_; } @@ -230,7 +139,7 @@ void Registry::perform_auto_export() bool include_stats; { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); if(!auto_export_enabled_) { return; @@ -243,31 +152,15 @@ void Registry::perform_auto_export() export_json_to_file(filename, include_stats); } -std::size_t Registry::merge_from(const Registry& other, Priority priority) -{ - auto other_kernels = other.get_all(); - std::size_t merged_count = 0; - - for(const auto& kernel : other_kernels) - { - if(register_kernel(kernel, priority)) - { - merged_count++; - } - } - - return merged_count; -} - std::size_t Registry::filter_by_arch(const std::string& gpu_arch) { ArchFilter filter(gpu_arch); std::vector to_remove; { - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex()); - for(const auto& pair : kernels_) + for(const auto& pair : entries()) { if(!filter.is_valid(pair.second.instance->get_key())) { @@ -277,12 +170,18 @@ std::size_t Registry::filter_by_arch(const std::string& gpu_arch) for(const auto& key : to_remove) { - kernels_.erase(key); + entries_mut().erase(key); } } return to_remove.size(); } +Registry& Registry::instance() +{ + static Registry global_registry; + return global_registry; +} + } // namespace dispatcher -} // namespace ck_tile +} // namespace ck_tile \ No newline at end of file diff --git a/dispatcher/tests/CMakeLists.txt b/dispatcher/tests/CMakeLists.txt index 6c20c18c95..a54feba284 100644 --- a/dispatcher/tests/CMakeLists.txt +++ b/dispatcher/tests/CMakeLists.txt @@ -217,6 +217,10 @@ endforeach() # Standalone integration tests (with their own main()) set(STANDALONE_TESTS test_minimal.cpp + test_grouped_conv_config.cpp + test_grouped_conv_problem.cpp + test_grouped_conv_kernel_decl.cpp + test_grouped_conv_registry.cpp ) foreach(test_source ${STANDALONE_TESTS}) diff --git a/dispatcher/tests/test_autocorrect.py b/dispatcher/tests/test_autocorrect.py index 0ec3ebda3c..3f52049f74 100644 --- a/dispatcher/tests/test_autocorrect.py +++ b/dispatcher/tests/test_autocorrect.py @@ -42,10 +42,10 @@ from compile_gemm_examples import ( # noqa: E402 expand_declaration_with_arch_filter, is_wildcard_declaration, ) -from compile_conv_examples import ( # noqa: E402 - validate_conv_kernel_config, - expand_conv_declaration_with_arch_filter, - is_conv_wildcard_declaration, +from compile_grouped_conv_examples import ( # noqa: E402 + validate_grouped_conv_kernel_config as validate_conv_kernel_config, + expand_grouped_conv_declaration_with_arch_filter as expand_conv_declaration_with_arch_filter, + is_grouped_conv_wildcard_declaration as is_conv_wildcard_declaration, ) from ctypes_utils import auto_correct_kernel_config, KernelConfig # noqa: E402 diff --git a/dispatcher/tests/test_codegen_common.py b/dispatcher/tests/test_codegen_common.py new file mode 100644 index 0000000000..2efeaefb4d --- /dev/null +++ b/dispatcher/tests/test_codegen_common.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for codegen/codegen_common.py -- shared infrastructure for GEMM and grouped conv codegen. + +Phase 1a TDD: these tests are written BEFORE the implementation exists. +Run: python3 -m pytest tests/test_codegen_common.py -v +""" + +import sys +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from codegen_common import ( # noqa: E402 + TileConfig, + TraitConfigBase, + CommonTypeMappings, + generate_cpp_compilation_unit, + parallel_generate, + valid_wave_configs, + valid_warp_configs, + valid_trait_configs, + needs_wave_expansion, + needs_warp_expansion, + needs_pipeline_expansion, +) + + +class TestTileConfig(unittest.TestCase): + """TileConfig dataclass tests.""" + + def test_valid_config(self): + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertTrue(tc.is_valid()) + + def test_zero_tile_invalid(self): + tc = TileConfig(0, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertFalse(tc.is_valid()) + + def test_non_divisible_invalid(self): + tc = TileConfig(127, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertFalse(tc.is_valid()) + + def test_all_fields_accessible(self): + tc = TileConfig(256, 128, 64, 4, 1, 1, 32, 32, 16) + self.assertEqual(tc.tile_m, 256) + self.assertEqual(tc.tile_n, 128) + self.assertEqual(tc.tile_k, 64) + self.assertEqual(tc.warp_m, 4) + self.assertEqual(tc.warp_n, 1) + self.assertEqual(tc.warp_k, 1) + self.assertEqual(tc.warp_tile_m, 32) + self.assertEqual(tc.warp_tile_n, 32) + self.assertEqual(tc.warp_tile_k, 16) + + def test_small_valid_config(self): + tc = TileConfig(16, 16, 16, 1, 1, 1, 16, 16, 16) + self.assertTrue(tc.is_valid()) + + +class TestTraitConfigBase(unittest.TestCase): + """TraitConfigBase dataclass tests.""" + + def test_valid_intrawave(self): + tc = TraitConfigBase("compv3", "cshuffle", "intrawave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_invalid_interwave_compv3(self): + tc = TraitConfigBase("compv3", "cshuffle", "interwave", False, False, False) + self.assertFalse(tc.is_valid()) + + def test_invalid_interwave_compv4(self): + tc = TraitConfigBase("compv4", "cshuffle", "interwave", False, False, False) + self.assertFalse(tc.is_valid()) + + def test_valid_mem_interwave(self): + tc = TraitConfigBase("mem", "cshuffle", "interwave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_valid_mem_intrawave(self): + tc = TraitConfigBase("mem", "cshuffle", "intrawave", False, False, False) + self.assertTrue(tc.is_valid()) + + def test_padding_fields(self): + tc = TraitConfigBase("compv3", "cshuffle", "intrawave", True, True, True) + self.assertTrue(tc.pad_m) + self.assertTrue(tc.pad_n) + self.assertTrue(tc.pad_k) + + +class TestCommonTypeMappings(unittest.TestCase): + """CommonTypeMappings tests.""" + + def test_dtype_to_ck(self): + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp16"], "fp16_t") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["bf16"], "bf16_t") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp32"], "float") + self.assertEqual(CommonTypeMappings.DTYPE_TO_CK["fp8"], "fp8_t") + + def test_pipeline_to_ck(self): + self.assertEqual( + CommonTypeMappings.PIPELINE_TO_CK["mem"], "GemmPipelineAgBgCrMem" + ) + self.assertIn("compv3", CommonTypeMappings.PIPELINE_TO_CK) + self.assertIn("compv4", CommonTypeMappings.PIPELINE_TO_CK) + + def test_pipeline_to_base(self): + self.assertIn("mem", CommonTypeMappings.PIPELINE_TO_BASE) + self.assertIn("compv3", CommonTypeMappings.PIPELINE_TO_BASE) + self.assertIn("compv4", CommonTypeMappings.PIPELINE_TO_BASE) + + def test_scheduler_to_ck(self): + self.assertIn("intrawave", CommonTypeMappings.SCHEDULER_TO_CK) + self.assertIn("interwave", CommonTypeMappings.SCHEDULER_TO_CK) + + def test_epilogue_to_dispatcher(self): + self.assertIn("cshuffle", CommonTypeMappings.EPILOGUE_TO_DISPATCHER) + self.assertIn("default", CommonTypeMappings.EPILOGUE_TO_DISPATCHER) + + def test_layout_to_ck(self): + self.assertIn("r", CommonTypeMappings.LAYOUT_TO_CK) + self.assertIn("c", CommonTypeMappings.LAYOUT_TO_CK) + + def test_get_output_dtype(self): + self.assertEqual(CommonTypeMappings.get_output_dtype("fp8"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("bf8"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("fp16"), "fp16") + self.assertEqual(CommonTypeMappings.get_output_dtype("fp32"), "fp32") + + +class TestGenerateCppCompilationUnit(unittest.TestCase): + """Tests for generate_cpp_compilation_unit.""" + + def test_includes_kernel_header(self): + result = generate_cpp_compilation_unit("my_kernel") + self.assertIn('#include "my_kernel.hpp"', result) + + def test_contains_pragma_once_or_guard(self): + result = generate_cpp_compilation_unit("test_kernel") + self.assertIn("test_kernel", result) + + def test_different_names_different_output(self): + a = generate_cpp_compilation_unit("kernel_a") + b = generate_cpp_compilation_unit("kernel_b") + self.assertNotEqual(a, b) + + +class TestParallelGenerate(unittest.TestCase): + """Tests for parallel_generate helper.""" + + def _dummy_generate(self, item): + return f"generated_{item}" + + def test_parallel_returns_all(self): + items = ["a", "b", "c", "d"] + results = parallel_generate(self._dummy_generate, items, parallel=True) + self.assertEqual(len(results), 4) + for item in items: + self.assertIn(f"generated_{item}", results) + + def test_sequential_returns_all(self): + items = ["x", "y", "z"] + results = parallel_generate(self._dummy_generate, items, parallel=False) + self.assertEqual(len(results), 3) + for item in items: + self.assertIn(f"generated_{item}", results) + + def test_empty_items(self): + results = parallel_generate(self._dummy_generate, [], parallel=True) + self.assertEqual(len(results), 0) + + def test_logs_per_kernel_progress(self): + items = ["k1", "k2"] + with self.assertLogs(level="INFO") as cm: + parallel_generate(self._dummy_generate, items, parallel=False) + log_output = "\n".join(cm.output) + self.assertIn("k1", log_output) + self.assertIn("k2", log_output) + + +class TestArchAwareExpansion(unittest.TestCase): + """Tests for arch-aware expansion helpers (best-of-conv).""" + + def test_valid_wave_configs_gfx942(self): + configs = valid_wave_configs("gfx942") + self.assertIsInstance(configs, list) + self.assertIn([2, 2, 1], configs) + self.assertIn([1, 4, 1], configs) + + def test_valid_wave_configs_unknown_arch(self): + configs = valid_wave_configs("gfx_unknown") + self.assertIsInstance(configs, list) + self.assertGreater(len(configs), 0) + + def test_valid_warp_configs_gfx942_fp16(self): + configs = valid_warp_configs("gfx942", "fp16") + self.assertIsInstance(configs, list) + self.assertIn([32, 32, 16], configs) + + def test_valid_warp_configs_unknown_arch(self): + configs = valid_warp_configs("gfx_unknown", "fp16") + self.assertIsInstance(configs, list) + self.assertGreater(len(configs), 0) + + def test_valid_trait_configs_excludes_interwave_compute(self): + configs = valid_trait_configs() + self.assertIsInstance(configs, list) + self.assertNotIn(("compv3", "cshuffle", "interwave"), configs) + self.assertNotIn(("compv4", "cshuffle", "interwave"), configs) + + def test_valid_trait_configs_includes_mem_interwave(self): + configs = valid_trait_configs() + has_mem_interwave = any(p == "mem" and s == "interwave" for p, s in configs) + self.assertTrue(has_mem_interwave) + + def test_needs_wave_expansion_wildcard(self): + self.assertTrue(needs_wave_expansion({"wave_m": -1, "wave_n": 2})) + self.assertTrue(needs_wave_expansion({"wave_m": 2, "wave_n": -1})) + + def test_needs_wave_expansion_explicit(self): + self.assertFalse(needs_wave_expansion({"wave_m": 2, "wave_n": 2})) + + def test_needs_warp_expansion_wildcard(self): + self.assertTrue(needs_warp_expansion({"warp_m": -1, "warp_n": 32})) + + def test_needs_warp_expansion_explicit(self): + self.assertFalse(needs_warp_expansion({"warp_m": 32, "warp_n": 32})) + + def test_needs_pipeline_expansion_wildcard(self): + self.assertTrue(needs_pipeline_expansion({"pipeline": "*"})) + + def test_needs_pipeline_expansion_explicit(self): + self.assertFalse(needs_pipeline_expansion({"pipeline": "compv4"})) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_dispatcher_common.py b/dispatcher/tests/test_dispatcher_common.py new file mode 100644 index 0000000000..2c0fc8307c --- /dev/null +++ b/dispatcher/tests/test_dispatcher_common.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +Tests for python/dispatcher_common.py -- shared Python dispatcher utilities. + +Phase 1b TDD: tests written BEFORE implementation exists. +Run: python3 -m pytest tests/test_dispatcher_common.py -v +""" + +import io +import sys +import unittest +from pathlib import Path +from unittest.mock import patch + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ( # noqa: E402 + get_dispatcher_root, + get_ck_root, + get_build_dir, + get_generated_kernels_dir, + get_arch_filter_data, + ValidationResultBase, + validate_wave_config, + validate_warp_tile_config, + validate_trait_combo, + auto_correct_wave, + auto_correct_trait, + Colors, + print_phase, + print_success, + print_error, + print_info, + cleanup_generated_kernels, +) + + +class TestPathHelpers(unittest.TestCase): + """Tests for path helper functions.""" + + def test_dispatcher_root_contains_codegen(self): + root = get_dispatcher_root() + self.assertTrue((root / "codegen").exists()) + + def test_ck_root_contains_include_or_is_parent(self): + root = get_ck_root() + self.assertTrue(root.exists()) + self.assertEqual(root, get_dispatcher_root().parent) + + def test_build_dir_is_under_dispatcher(self): + build = get_build_dir() + self.assertEqual(build.parent, get_dispatcher_root()) + + def test_generated_kernels_dir_under_build(self): + gen_dir = get_generated_kernels_dir() + self.assertEqual(gen_dir.parent, get_build_dir()) + + +class TestGetArchFilterData(unittest.TestCase): + """Tests for get_arch_filter_data.""" + + def test_returns_dict(self): + data = get_arch_filter_data() + self.assertIsInstance(data, dict) + + def test_has_warp_combos(self): + data = get_arch_filter_data() + self.assertIn("warp_combos", data) + + def test_has_warp_tile_combos(self): + data = get_arch_filter_data() + self.assertIn("warp_tile_combos", data) + + def test_has_trait_unsupported(self): + data = get_arch_filter_data() + self.assertIn("trait_unsupported", data) + + def test_has_supported_archs(self): + data = get_arch_filter_data() + self.assertIn("supported_archs", data) + self.assertIn("gfx942", data["supported_archs"]) + + def test_gfx942_wave_configs(self): + data = get_arch_filter_data() + gfx942 = data["warp_combos"].get("gfx942", []) + self.assertIn([2, 2, 1], gfx942) + + +class TestValidationResultBase(unittest.TestCase): + """Tests for ValidationResultBase dataclass.""" + + def test_valid_result(self): + vr = ValidationResultBase(is_valid=True) + self.assertTrue(vr.is_valid) + self.assertEqual(vr.errors, []) + self.assertEqual(vr.warnings, []) + self.assertEqual(vr.suggested_fixes, {}) + + def test_invalid_result(self): + vr = ValidationResultBase( + is_valid=False, + errors=["bad wave"], + suggested_fixes={"wave_m": 2}, + ) + self.assertFalse(vr.is_valid) + self.assertEqual(len(vr.errors), 1) + self.assertIn("wave_m", vr.suggested_fixes) + + +class TestValidateWaveConfig(unittest.TestCase): + """Tests for validate_wave_config.""" + + def test_valid_wave(self): + is_valid, msg = validate_wave_config([2, 2, 1], "gfx942") + self.assertTrue(is_valid) + self.assertEqual(msg, "") + + def test_invalid_wave(self): + is_valid, msg = validate_wave_config([3, 3, 1], "gfx942") + self.assertFalse(is_valid) + self.assertIn("wave", msg.lower()) + + +class TestValidateWarpTileConfig(unittest.TestCase): + """Tests for validate_warp_tile_config.""" + + def test_valid_warp_tile(self): + is_valid, msg = validate_warp_tile_config([32, 32, 16], "gfx942", "fp16") + self.assertTrue(is_valid) + + def test_invalid_warp_tile(self): + is_valid, msg = validate_warp_tile_config([99, 99, 99], "gfx942", "fp16") + self.assertFalse(is_valid) + self.assertIn("warp", msg.lower()) + + +class TestValidateTraitCombo(unittest.TestCase): + """Tests for validate_trait_combo.""" + + def test_valid_trait(self): + is_valid, msg = validate_trait_combo("compv3", "cshuffle", "intrawave") + self.assertTrue(is_valid) + + def test_invalid_trait_interwave_compute(self): + is_valid, msg = validate_trait_combo("compv4", "cshuffle", "interwave") + self.assertFalse(is_valid) + + def test_valid_mem_interwave(self): + is_valid, msg = validate_trait_combo("mem", "cshuffle", "interwave") + self.assertTrue(is_valid) + + +class TestAutoCorrectWave(unittest.TestCase): + """Tests for auto_correct_wave.""" + + def test_corrects_invalid_wave(self): + corrected = auto_correct_wave([1, 1, 1], "gfx942") + self.assertIsInstance(corrected, list) + self.assertEqual(len(corrected), 3) + data = get_arch_filter_data() + valid_waves = data["warp_combos"].get("gfx942", [[2, 2, 1]]) + self.assertIn(corrected, valid_waves) + + +class TestAutoCorrectTrait(unittest.TestCase): + """Tests for auto_correct_trait.""" + + def test_corrects_invalid_scheduler(self): + corrected_pipeline, corrected_scheduler = auto_correct_trait( + "compv4", "interwave" + ) + self.assertEqual(corrected_scheduler, "intrawave") + + +class TestColors(unittest.TestCase): + """Tests for Colors class (cross-platform ANSI support from conv).""" + + def test_green_returns_string(self): + result = Colors.green("ok") + self.assertIn("ok", result) + + def test_red_returns_string(self): + result = Colors.red("error") + self.assertIn("error", result) + + def test_yellow_returns_string(self): + result = Colors.yellow("warn") + self.assertIn("warn", result) + + def test_bold_returns_string(self): + result = Colors.bold("title") + self.assertIn("title", result) + + def test_plain_mode_no_ansi(self): + with patch.object(Colors, "_use_color", return_value=False): + result = Colors.green("plain") + self.assertEqual(result, "plain") + + +class TestPhasedOutput(unittest.TestCase): + """Tests for phased output helpers.""" + + def test_print_phase(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_phase(1, "Setup") + self.assertIn("Setup", buf.getvalue()) + + def test_print_success(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_success("Done") + self.assertIn("Done", buf.getvalue()) + + def test_print_error(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_error("Oops") + self.assertIn("Oops", buf.getvalue()) + + def test_print_info(self): + buf = io.StringIO() + with patch("sys.stdout", buf): + print_info("FYI") + self.assertIn("FYI", buf.getvalue()) + + +class TestCleanup(unittest.TestCase): + """Tests for cleanup_generated_kernels.""" + + def test_cleanup_nonexistent_dir_no_error(self): + cleanup_generated_kernels(Path("/tmp/nonexistent_ck_test_dir_12345")) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_examples_integration.py b/dispatcher/tests/test_examples_integration.py index cfd18a3305..d02ea69787 100644 --- a/dispatcher/tests/test_examples_integration.py +++ b/dispatcher/tests/test_examples_integration.py @@ -28,14 +28,18 @@ sys.path.insert(0, str(PYTHON_DIR)) def run_python_example( - example_path: Path, timeout: int = 120 + example_path: Path, timeout: int = 120, extra_args: list = None ) -> subprocess.CompletedProcess: """Run a Python example and capture output.""" env = os.environ.copy() env["PYTHONPATH"] = str(PYTHON_DIR) + cmd = [sys.executable, str(example_path)] + if extra_args: + cmd.extend(extra_args) + return subprocess.run( - [sys.executable, str(example_path)], + cmd, capture_output=True, text=True, timeout=timeout, @@ -111,61 +115,74 @@ class TestGemmPythonExamples(unittest.TestCase): result = run_python_example(example) self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - # Should pass validation self.assertIn("PASS", result.stdout.upper(), "Validation should pass") class TestConvPythonExamples(unittest.TestCase): - """Test Conv Python examples.""" + """Test grouped conv Python examples.""" @classmethod def setUpClass(cls): """Check if examples directory exists.""" - cls.conv_examples_dir = EXAMPLES_DIR / "conv" / "python" + cls.conv_examples_dir = EXAMPLES_DIR / "grouped_conv" / "python" if not cls.conv_examples_dir.exists(): - raise unittest.SkipTest("Conv Python examples not found") + raise unittest.SkipTest("Grouped conv Python examples not found") - def test_01_basic_conv(self): - """Test basic conv example.""" - example = self.conv_examples_dir / "01_basic_conv.py" + def test_01_basic_grouped_conv(self): + """Test basic grouped conv example.""" + example = self.conv_examples_dir / "01_basic_grouped_conv.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + self.assertIn("PASS", result.stdout.upper()) - def test_02_conv2d_fwd(self): - """Test 2D forward conv example.""" - example = self.conv_examples_dir / "02_conv2d_fwd.py" + def test_02_forward(self): + """Test forward conv example (2D + 3D).""" + example = self.conv_examples_dir / "02_forward.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) - def test_03_conv3d_fwd(self): - """Test 3D forward conv example.""" - example = self.conv_examples_dir / "03_conv3d_fwd.py" + def test_03_bwd_data(self): + """Test backward data example.""" + example = self.conv_examples_dir / "03_bwd_data.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) - def test_07_validation(self): - """Test validation example.""" - example = self.conv_examples_dir / "07_validation.py" + def test_04_bwd_weight(self): + """Test backward weight example.""" + example = self.conv_examples_dir / "04_bwd_weight.py" if not example.exists(): self.skipTest(f"{example.name} not found") - result = run_python_example(example) - self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + self.assertIn("PASS", result.stdout.upper()) + + def test_05_benchmark(self): + """Test benchmark example.""" + example = self.conv_examples_dir / "05_benchmark.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + result = run_python_example( + example, extra_args=["--warmup", "1", "--repeat", "1"] + ) + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) + + def test_06_registry_json(self): + """Test registry + heuristic + JSON example.""" + example = self.conv_examples_dir / "06_registry_json.py" + if not example.exists(): + self.skipTest(f"{example.name} not found") + result = run_python_example(example) + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) class TestGemmCppExamples(unittest.TestCase): @@ -195,18 +212,18 @@ class TestGemmCppExamples(unittest.TestCase): self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - def test_gemm_04_validation(self): - """Test validation GEMM C++ example.""" - result = run_cpp_example("gemm_04_validation") + def test_gemm_03_benchmark_validation(self): + """Test benchmark+validation GEMM C++ example.""" + result = run_cpp_example("gemm_03_benchmark_validation") if result is None: - self.skipTest("gemm_04_validation not built") + self.skipTest("gemm_03_benchmark_validation not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") self.assertIn("PASS", result.stdout.upper(), "Validation should pass") class TestConvCppExamples(unittest.TestCase): - """Test Conv C++ examples.""" + """Test grouped conv C++ examples.""" @classmethod def setUpClass(cls): @@ -215,23 +232,29 @@ class TestConvCppExamples(unittest.TestCase): if not cls.examples_dir.exists(): raise unittest.SkipTest("C++ examples not built") - def test_conv_01_forward(self): - """Test forward conv C++ example.""" - result = run_cpp_example("conv_01_forward") + def test_grouped_conv_01_basic(self): + """Test basic grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_01_basic") if result is None: - self.skipTest("conv_01_forward not built") - + self.skipTest("grouped_conv_01_basic not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("TFLOPS", result.stdout, "Should report TFLOPS") + self.assertIn("PASS", result.stdout.upper()) - def test_conv_02_validation(self): - """Test validation conv C++ example.""" - result = run_cpp_example("conv_02_validation") + def test_grouped_conv_02_all_dirs(self): + """Test all directions grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_02_all_dirs") if result is None: - self.skipTest("conv_02_validation not built") - + self.skipTest("grouped_conv_02_all_dirs not built") self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") - self.assertIn("PASS", result.stdout.upper(), "Validation should pass") + self.assertIn("PASS", result.stdout.upper()) + + def test_grouped_conv_03_bench_val(self): + """Test benchmark+validation grouped conv C++ example.""" + result = run_cpp_example("grouped_conv_03_bench_val") + if result is None: + self.skipTest("grouped_conv_03_bench_val not built") + self.assertEqual(result.returncode, 0, f"Example failed:\n{result.stderr}") + self.assertIn("PASS", result.stdout.upper()) class TestUtilityImports(unittest.TestCase): @@ -246,14 +269,18 @@ class TestUtilityImports(unittest.TestCase): except ImportError as e: self.fail(f"Failed to import ctypes_utils: {e}") - def test_import_conv_utils(self): - """Test importing conv_utils.""" + def test_import_grouped_conv_utils(self): + """Test importing grouped_conv_utils.""" try: - from conv_utils import ConvSignature, ConvAlgorithm, ConvProblem # noqa: F401 + from grouped_conv_utils import ( # noqa: F401 + GroupedConvValidationResult, + validate_grouped_conv_config, + GroupedConvDataType, + ) self.assertTrue(True) except ImportError as e: - self.fail(f"Failed to import conv_utils: {e}") + self.fail(f"Failed to import grouped_conv_utils: {e}") def test_kernel_config_creation(self): """Test creating a KernelConfig.""" @@ -272,22 +299,19 @@ class TestUtilityImports(unittest.TestCase): self.assertEqual(config.dtype_a, "fp16") self.assertEqual(config.layout_a, "row") - def test_conv_signature_creation(self): - """Test creating a ConvSignature.""" - from conv_utils import ConvSignature + def test_grouped_conv_default_config(self): + """Test creating a grouped conv default config.""" + from grouped_conv_utils import get_grouped_conv_default_config - sig = ConvSignature( - dtype_in="fp16", - dtype_wei="fp16", - dtype_out="fp16", - dtype_acc="fp32", - layout="nhwgc", - direction="forward", - num_dims=2, + config = get_grouped_conv_default_config( + variant="forward", + ndim_spatial=2, + arch="gfx942", ) - self.assertEqual(sig.dtype_in, "fp16") - self.assertEqual(sig.direction, "forward") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertEqual(d["variant"], "forward") + self.assertEqual(d["arch"], "gfx942") class TestAutoCorrection(unittest.TestCase): @@ -316,21 +340,22 @@ class TestAutoCorrection(unittest.TestCase): self.assertTrue(was_modified, "Config should be modified") self.assertGreater(len(corrections), 0, "Should have corrections") - def test_conv_auto_correct(self): - """Test Conv auto-correction.""" - from conv_utils import auto_correct_conv_config - - # Call with invalid wave config parameters - corrected, was_modified, corrections = auto_correct_conv_config( - wave_m=99, # Invalid - wave_n=99, # Invalid - wave_k=99, # Invalid - dtype="fp16", - arch="gfx942", + def test_grouped_conv_auto_correct(self): + """Test Grouped Conv auto-correction.""" + from grouped_conv_utils import ( + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, ) - self.assertTrue(was_modified, "Config should be modified") - self.assertGreater(len(corrections), 0, "Should have corrections") + config = get_grouped_conv_default_config() + d = config.to_dict() if hasattr(config, "to_dict") else config + d["tile_config"]["warp_m"] = [99] + d["tile_config"]["warp_n"] = [99] + + corrected, result = auto_correct_grouped_conv_config(d) + + self.assertIsInstance(corrected, dict) + self.assertIn("tile_config", corrected) if __name__ == "__main__": diff --git a/dispatcher/tests/test_grouped_conv_codegen.py b/dispatcher/tests/test_grouped_conv_codegen.py new file mode 100644 index 0000000000..acfa5abd8f --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_codegen.py @@ -0,0 +1,589 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +TDD tests for codegen/unified_grouped_conv_codegen.py -- grouped convolution code generator. + +These tests are written BEFORE the implementation exists. +Run: python3 -m pytest dispatcher/tests/test_grouped_conv_codegen.py -v +""" + +import sys +import unittest +from pathlib import Path +from unittest.mock import patch + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) +sys.path.insert(0, str(DISPATCHER_DIR / "python")) + +from codegen_common import TileConfig, TraitConfigBase # noqa: E402 + +from unified_grouped_conv_codegen import ( # noqa: E402 + GroupedConvVariant, + GroupedConvLayout, + GroupedConvKernelConfig, + GroupedConvTypeMappings, + GroupedConvTraitConfig, + CKTileGroupedConvKernelGenerator, + GroupedConvDispatcherWrapperGenerator, + UnifiedGroupedConvCodegen, +) + + +# ============================================================================= +# TestGroupedConvVariant +# ============================================================================= + + +class TestGroupedConvVariant(unittest.TestCase): + """Test GroupedConvVariant enum values.""" + + def test_forward_value(self): + self.assertEqual(GroupedConvVariant.FORWARD.value, "forward") + + def test_backward_data_value(self): + self.assertEqual(GroupedConvVariant.BACKWARD_DATA.value, "bwd_data") + + def test_backward_weight_value(self): + self.assertEqual(GroupedConvVariant.BACKWARD_WEIGHT.value, "bwd_weight") + + def test_all_variants_exist(self): + self.assertIn(GroupedConvVariant.FORWARD, GroupedConvVariant) + self.assertIn(GroupedConvVariant.BACKWARD_DATA, GroupedConvVariant) + self.assertIn(GroupedConvVariant.BACKWARD_WEIGHT, GroupedConvVariant) + + +# ============================================================================= +# TestGroupedConvLayout +# ============================================================================= + + +class TestGroupedConvLayout(unittest.TestCase): + """Test GroupedConvLayout enum for 1D/2D/3D layouts.""" + + def test_nhwgc_value(self): + self.assertEqual(GroupedConvLayout.NHWGC.value, "NHWGC") + + def test_gkyxc_value(self): + self.assertEqual(GroupedConvLayout.GKYXC.value, "GKYXC") + + def test_nhwgk_value(self): + self.assertEqual(GroupedConvLayout.NHWGK.value, "NHWGK") + + def test_1d_layouts_exist(self): + """1D conv layouts (e.g., NWGC, GYXC, NWGK).""" + layouts_1d = [ + lay + for lay in GroupedConvLayout + if "W" in lay.value and "H" not in lay.value + ] + self.assertGreater(len(layouts_1d), 0) + + def test_2d_layouts_exist(self): + """2D conv layouts (e.g., NHWGC, GKYXC, NHWGK).""" + layouts_2d = [lay for lay in GroupedConvLayout if "HW" in lay.value] + self.assertGreater(len(layouts_2d), 0) + + def test_3d_layouts_exist(self): + """3D conv layouts (e.g., NDHWGC, GDKYXC).""" + layouts_3d = [ + lay for lay in GroupedConvLayout if "D" in lay.value or "DHW" in lay.value + ] + self.assertGreater(len(layouts_3d), 0) + + +# ============================================================================= +# TestGroupedConvKernelConfig +# ============================================================================= + + +class TestGroupedConvKernelConfig(unittest.TestCase): + """Test GroupedConvKernelConfig dataclass.""" + + def _make_tile(self): + return TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + + def _make_trait(self): + return GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + + def test_name_contains_grouped_conv_fwd(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + name = config.name("fp16") + self.assertIn("grouped_conv_fwd", name) + + def test_name_backward_data_contains_bwd_data(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.BACKWARD_DATA, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + name = config.name("fp16") + self.assertIn("bwd_data", name) + + def test_is_valid_for_arch_supported(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + self.assertTrue(config.is_valid_for_arch("gfx942")) + + def test_is_valid_for_arch_unsupported(self): + config = GroupedConvKernelConfig( + tile=self._make_tile(), + trait=self._make_trait(), + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + self.assertFalse(config.is_valid_for_arch("gfx600")) + + +# ============================================================================= +# TestGroupedConvTypeMappings +# ============================================================================= + + +class TestGroupedConvTypeMappings(unittest.TestCase): + """Test GroupedConvTypeMappings class.""" + + def test_dtype_to_ck_fp16(self): + self.assertEqual(GroupedConvTypeMappings.DTYPE_TO_CK["fp16"], "half_t") + + def test_dtype_to_ck_bf16(self): + self.assertIn("bf16", GroupedConvTypeMappings.DTYPE_TO_CK) + + def test_dtype_to_ck_fp32(self): + self.assertIn("fp32", GroupedConvTypeMappings.DTYPE_TO_CK) + + def test_get_layouts_2d_has_in_wei_out_keys(self): + layouts = GroupedConvTypeMappings.get_layouts(2) + self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + def test_get_layouts_2d_returns_dict(self): + layouts = GroupedConvTypeMappings.get_layouts(2) + self.assertIsInstance(layouts, dict) + + def test_get_layouts_1d(self): + layouts = GroupedConvTypeMappings.get_layouts(1) + self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + def test_get_layouts_3d(self): + layouts = GroupedConvTypeMappings.get_layouts(3) + self.assertIn("in", layouts) + self.assertIn("wei", layouts) + self.assertIn("out", layouts) + + +# ============================================================================= +# TestCKTileGroupedConvKernelGenerator +# ============================================================================= + + +class TestCKTileGroupedConvKernelGenerator(unittest.TestCase): + """Test CKTileGroupedConvKernelGenerator.generate().""" + + def _make_config(self): + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + return GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + def test_generate_contains_pragma_once(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("#pragma once", result) + + def test_generate_contains_forward_kernel_include(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("grouped_convolution_forward_kernel.hpp", result) + + def test_generate_returns_non_empty_string(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIsInstance(result, str) + self.assertGreater(len(result), 100) + + def test_generate_valid_cpp_structure(self): + gen = CKTileGroupedConvKernelGenerator("fp16") + config = self._make_config() + result = gen.generate(config) + self.assertIn("#include", result) + self.assertIn("ck_tile", result) + + +# ============================================================================= +# TestGroupedConvDispatcherWrapperGenerator +# ============================================================================= + + +class TestGroupedConvDispatcherWrapperGenerator(unittest.TestCase): + """Test GroupedConvDispatcherWrapperGenerator.generate().""" + + def _make_config(self): + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + return GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + def test_generate_contains_dispatcher_registration(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("dispatcher", result) + self.assertIn("KernelKey", result) + self.assertIn("KernelInstancePtr", result) + + def test_generate_contains_pragma_once(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("#pragma once", result) + + def test_generate_valid_cpp(self): + gen = GroupedConvDispatcherWrapperGenerator("fp16") + config = self._make_config() + kernel_path = DISPATCHER_DIR / "build" / "generated" / "test_kernel.hpp" + output_dir = DISPATCHER_DIR / "build" / "generated" + result = gen.generate(config, kernel_path, output_dir) + self.assertIn("#include", result) + self.assertIn("namespace", result) + + +# ============================================================================= +# TestUnifiedGroupedConvCodegen +# ============================================================================= + + +class TestUnifiedGroupedConvCodegen(unittest.TestCase): + """Test UnifiedGroupedConvCodegen.generate_all().""" + + def test_generate_all_returns_dict_with_expected_keys(self): + output_dir = DISPATCHER_DIR / "build" / "generated" / "grouped_conv" + output_dir.mkdir(parents=True, exist_ok=True) + codegen = UnifiedGroupedConvCodegen( + output_dir=output_dir, + datatype="fp16", + ndim_spatial=2, + gpu_target="gfx942", + ) + with patch.object( + codegen, + "_get_configs", + return_value=[], # Mock empty config list for fast test + ): + results = codegen.generate_all(parallel=False) + self.assertIn("kernels", results) + self.assertIn("wrappers", results) + self.assertIn("failed", results) + self.assertIsInstance(results["kernels"], list) + self.assertIsInstance(results["wrappers"], list) + self.assertIsInstance(results["failed"], list) + + def test_generate_all_with_mock_config_produces_output(self): + output_dir = DISPATCHER_DIR / "build" / "generated" / "grouped_conv_test" + output_dir.mkdir(parents=True, exist_ok=True) + codegen = UnifiedGroupedConvCodegen( + output_dir=output_dir, + datatype="fp16", + ndim_spatial=2, + gpu_target="gfx942", + ) + # Use a real config - patch the config source to return one config + tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + config = GroupedConvKernelConfig( + tile=tile, + trait=trait, + variant=GroupedConvVariant.FORWARD, + ndim_spatial=2, + arch="gfx942", + layout=GroupedConvLayout.NHWGC, + vector_sizes=(4, 4, 4), + ) + + with patch.object(codegen, "_get_configs", return_value=[config]): + results = codegen.generate_all(parallel=False) + self.assertIsInstance(results, dict) + self.assertIn("kernels", results) + + +# ============================================================================= +# TestSharedImports +# ============================================================================= + + +class TestSharedImports(unittest.TestCase): + """Verify TileConfig from codegen_common and GroupedConvTraitConfig extends TraitConfigBase.""" + + def test_tile_config_has_expected_fields(self): + """TileConfig from codegen_common has tile_m, tile_n, tile_k, etc.""" + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertEqual(tc.tile_m, 128) + self.assertEqual(tc.tile_n, 128) + self.assertEqual(tc.tile_k, 32) + self.assertEqual(tc.warp_m, 2) + self.assertEqual(tc.warp_n, 2) + self.assertEqual(tc.warp_k, 1) + self.assertEqual(tc.warp_tile_m, 32) + self.assertEqual(tc.warp_tile_n, 32) + self.assertEqual(tc.warp_tile_k, 16) + + def test_tile_config_is_from_codegen_common(self): + """TileConfig used by grouped conv is the same as codegen_common.TileConfig.""" + tc = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16) + self.assertTrue(tc.is_valid()) + + def test_grouped_conv_trait_config_extends_trait_config_base(self): + """GroupedConvTraitConfig extends TraitConfigBase.""" + self.assertTrue(issubclass(GroupedConvTraitConfig, TraitConfigBase)) + + def test_grouped_conv_trait_config_has_double_smem_buffer(self): + """GroupedConvTraitConfig has double_smem_buffer field.""" + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=True, + num_groups_to_merge=2, + ) + self.assertTrue(trait.double_smem_buffer) + self.assertEqual(trait.num_groups_to_merge, 2) + + def test_grouped_conv_trait_config_has_num_groups_to_merge(self): + """GroupedConvTraitConfig has num_groups_to_merge field.""" + trait = GroupedConvTraitConfig( + "mem", + "cshuffle", + "intrawave", + False, + False, + False, + double_smem_buffer=False, + num_groups_to_merge=4, + ) + self.assertEqual(trait.num_groups_to_merge, 4) + + def test_grouped_conv_trait_config_inherits_base_fields(self): + """GroupedConvTraitConfig inherits pipeline, epilogue, scheduler from base.""" + trait = GroupedConvTraitConfig( + "compv4", + "cshuffle", + "intrawave", + True, + True, + True, + double_smem_buffer=False, + num_groups_to_merge=1, + ) + self.assertEqual(trait.pipeline, "compv4") + self.assertEqual(trait.epilogue, "cshuffle") + self.assertEqual(trait.scheduler, "intrawave") + self.assertTrue(trait.pad_m) + self.assertTrue(trait.pad_n) + self.assertTrue(trait.pad_k) + + +# ============================================================================= +# TestTwoStageBwdWeightCodegen +# ============================================================================= + + +def _make_two_stage_config(): + """Helper: create a two-stage bwd_weight config.""" + return GroupedConvKernelConfig( + tile=TileConfig(16, 64, 64, 1, 4, 1, 16, 16, 32), + trait=GroupedConvTraitConfig( + pipeline="compv3", + epilogue="cshuffle", + scheduler="intrawave", + pad_m=True, + pad_n=True, + pad_k=True, + two_stage=True, + ), + variant=GroupedConvVariant.BACKWARD_WEIGHT, + ndim_spatial=2, + arch="gfx942", + ) + + +class TestTwoStageBwdWeightCodegen(unittest.TestCase): + """Tests for two-stage backward weight kernel generation.""" + + def test_kernel_name_contains_2stage(self): + config = _make_two_stage_config() + name = config.name("fp16") + self.assertIn("_2stage", name) + self.assertIn("bwd_weight", name) + + def test_single_stage_name_has_no_2stage(self): + config = _make_two_stage_config() + config.trait.two_stage = False + name = config.name("fp16") + self.assertNotIn("_2stage", name) + + def test_generate_contains_elementwise_include(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("elementwise.hpp", code) + + def test_generate_contains_workspace_type(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("WorkspaceDataType", code) + + def test_generate_contains_elementwise_kernel(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("ElementWiseKernel", code) + + def test_generate_contains_launch_kernel_time_mask(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("launch_kernel_time_mask", code) + + def test_generate_forces_vector_size_c_to_1(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("VectorSizeC_TwoStage = 1", code) + + def test_generate_contains_workspace_memset(self): + config = _make_two_stage_config() + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertIn("hipMemsetAsync", code) + + def test_single_stage_does_not_contain_workspace(self): + config = _make_two_stage_config() + config.trait.two_stage = False + gen = CKTileGroupedConvKernelGenerator( + "fp16", GroupedConvVariant.BACKWARD_WEIGHT + ) + code = gen.generate(config) + self.assertNotIn("WorkspaceDataType", code) + self.assertNotIn("ElementWiseKernel", code) + self.assertNotIn("launch_kernel_time_mask", code) + + def test_default_configs_include_two_stage(self): + from unified_grouped_conv_codegen import get_default_configs + + configs = get_default_configs( + arch="gfx942", + variants=[GroupedConvVariant.BACKWARD_WEIGHT], + ndims=[2], + ) + two_stage = [c for c in configs if c.trait.two_stage] + single_stage = [c for c in configs if not c.trait.two_stage] + self.assertGreater(len(two_stage), 0, "Should have two-stage configs") + self.assertGreater( + len(single_stage), 0, "Should still have single-stage configs" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_grouped_conv_config.cpp b/dispatcher/tests/test_grouped_conv_config.cpp new file mode 100644 index 0000000000..c9a1faeaf9 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_config.cpp @@ -0,0 +1,112 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvConfig using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_config.hpp" +#include +#include + +using namespace ck_tile::dispatcher; + +void test_grouped_conv_direction_enum() +{ + std::cout << " test_grouped_conv_direction_enum... "; + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::FORWARD) == + std::string("fwd")); + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::BACKWARD_DATA) == + std::string("bwd_data")); + assert(GroupedConvSignatureInfo::direction_str(GroupedConvDirection::BACKWARD_WEIGHT) == + std::string("bwd_weight")); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_signature_info() +{ + std::cout << " test_grouped_conv_signature_info... "; + GroupedConvSignatureInfo sig; + assert(sig.spatial_dim == 2); + assert(sig.direction == GroupedConvDirection::FORWARD); + assert(sig.in_type == "fp16"); + assert(sig.wei_type == "fp16"); + assert(sig.out_type == "fp16"); + assert(sig.acc_type == "fp32"); + assert(sig.num_groups == 1); + sig.in_type = "bf16"; + sig.num_groups = 4; + assert(sig.in_type == "bf16"); + assert(sig.num_groups == 4); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_algorithm_info() +{ + std::cout << " test_grouped_conv_algorithm_info... "; + GroupedConvAlgorithmInfo algo; + assert(algo.tile.m == 128); + assert(algo.tile.n == 128); + assert(algo.tile.k == 64); + assert(algo.pipeline == PipelineVersion::V4); + assert(algo.scheduler == PipelineScheduler::INTRAWAVE); + assert(GroupedConvAlgorithmInfo::pipeline_str(PipelineVersion::V4) == std::string("compv4")); + assert(GroupedConvAlgorithmInfo::scheduler_str(PipelineScheduler::INTRAWAVE) == + std::string("intrawave")); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_config() +{ + std::cout << " test_grouped_conv_config... "; + GroupedConvConfig cfg; + std::string name = cfg.name(); + assert(!name.empty()); + assert(name.find("grouped_conv_") != std::string::npos); + assert(name.find("fwd") != std::string::npos); + assert(name.find("fp16") != std::string::npos); + assert(name.find("2d") != std::string::npos); + + std::string brief = cfg.brief(); + assert(!brief.empty()); + assert(brief.find("2D") != std::string::npos || brief.find("Grouped") != std::string::npos); + + std::string detailed = cfg.detailed(); + assert(!detailed.empty()); + assert(detailed.find("Signature:") != std::string::npos); + assert(detailed.find("Algorithm:") != std::string::npos); + assert(detailed.find("Arch:") != std::string::npos); + std::cout << "PASSED\n"; +} + +void test_predefined_grouped_conv_configs() +{ + std::cout << " test_predefined_grouped_conv_configs... "; + configs::Memory mem_cfg; + assert(mem_cfg.algorithm.pipeline == PipelineVersion::MEMORY); + assert(mem_cfg.algorithm.tile.m == 128); + assert(mem_cfg.algorithm.tile.n == 32); + + configs::CompV3_Small compv3_small; + assert(compv3_small.algorithm.pipeline == PipelineVersion::V3); + assert(compv3_small.algorithm.tile.m == 16); + assert(compv3_small.algorithm.tile.n == 64); + + configs::CompV4 compv4; + assert(compv4.algorithm.pipeline == PipelineVersion::V4); + assert(compv4.algorithm.double_smem_buffer == true); + + configs::WMMA wmma_cfg; + assert(wmma_cfg.arch.name == "gfx1100"); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Config ===\n\n"; + test_grouped_conv_direction_enum(); + test_grouped_conv_signature_info(); + test_grouped_conv_algorithm_info(); + test_grouped_conv_config(); + test_predefined_grouped_conv_configs(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_kernel_decl.cpp b/dispatcher/tests/test_grouped_conv_kernel_decl.cpp new file mode 100644 index 0000000000..7b28a451bc --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_kernel_decl.cpp @@ -0,0 +1,141 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvKernelDecl using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_kernel_decl.hpp" +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_decl; + +void test_grouped_conv_signature_builder() +{ + std::cout << " test_grouped_conv_signature_builder... "; + GroupedConvSignature sig; + sig.dtype("fp16").layout("nhwc").conv_type("forward").dims(2).groups(4); + assert(sig.dtype_in_ == "fp16"); + assert(sig.dtype_wei_ == "fp16"); + assert(sig.dtype_out_ == "fp16"); + assert(sig.layout_ == "nhwc"); + assert(sig.conv_op_ == "forward"); + assert(sig.num_dims_ == 2); + assert(sig.groups_ == 4); + assert(sig.op_str() == "fwd"); + sig.conv_type("bwd_data"); + assert(sig.op_str() == "bwd_data"); + sig.conv_type("bwd_weight"); + assert(sig.op_str() == "bwd_weight"); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_algorithm_builder() +{ + std::cout << " test_grouped_conv_algorithm_builder... "; + GroupedConvAlgorithm algo; + algo.tile(128, 128, 64) + .wave(2, 2, 1) + .warp(32, 32, 16) + .pipeline("compv4") + .scheduler("intrawave"); + assert(algo.tile_m_ == 128); + assert(algo.tile_n_ == 128); + assert(algo.tile_k_ == 64); + assert(algo.wave_m_ == 2); + assert(algo.wave_n_ == 2); + assert(algo.warp_m_ == 32); + assert(algo.warp_n_ == 32); + assert(algo.pipeline_ == "compv4"); + assert(algo.scheduler_ == "intrawave"); + assert(!algo.needs_expansion()); + algo.wave_m_ = ANY_INT; + assert(algo.needs_wave_expansion()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_decl() +{ + std::cout << " test_grouped_conv_kernel_decl... "; + GroupedConvSignature sig; + sig.dtype("fp16").layout("nhwc").conv_type("forward").dims(2); + GroupedConvAlgorithm algo; + algo.tile(128, 128, 64).wave(2, 2, 1).warp(32, 32, 16); + GroupedConvKernelDecl decl(sig, algo, "gfx942"); + std::string name = decl.name(); + assert(!name.empty()); + assert(name.find("grouped_conv_") != std::string::npos); + assert(name.find("fwd") != std::string::npos); + assert(name.find("fp16") != std::string::npos); + assert(name.find("128x128x64") != std::string::npos); + assert(!decl.has_wildcards()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set() +{ + std::cout << " test_grouped_conv_kernel_set... "; + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + assert(set.size() == 1); + set.add("fp16", "nhwc", "forward", 256, 256); + assert(set.size() == 2); + const auto& decls = set.declarations(); + assert(decls[0].algorithm.tile_n_ == 128); + assert(decls[0].algorithm.tile_k_ == 128); + assert(decls[1].algorithm.tile_n_ == 256); + assert(decls[1].algorithm.tile_k_ == 256); + set.tag("test_set"); + assert(set.tag() == "test_set"); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set_merge() +{ + std::cout << " test_grouped_conv_kernel_set_merge... "; + GroupedConvKernelSet set1; + set1.add("fp16", "nhwc", "forward", 128, 128); + GroupedConvKernelSet set2; + set2.add("fp16", "nhwc", "forward", 256, 256); + set1.merge(set2); + assert(set1.size() == 2); + assert(set1.declarations()[0].algorithm.tile_n_ == 128); + assert(set1.declarations()[1].algorithm.tile_n_ == 256); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_kernel_set_registry() +{ + std::cout << " test_grouped_conv_kernel_set_registry... "; + auto& reg = GroupedConvKernelSetRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set("gconv_test", set); + assert(reg.has("gconv_test")); + assert(reg.size() >= 1); + + const auto& retrieved = reg.get("gconv_test"); + assert(retrieved.size() == 1); + + const auto& empty = reg.get("nonexistent"); + assert(empty.size() == 0); + + reg.clear(); + assert(!reg.has("gconv_test")); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Kernel Decl ===\n\n"; + test_grouped_conv_signature_builder(); + test_grouped_conv_algorithm_builder(); + test_grouped_conv_kernel_decl(); + test_grouped_conv_kernel_set(); + test_grouped_conv_kernel_set_merge(); + test_grouped_conv_kernel_set_registry(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_problem.cpp b/dispatcher/tests/test_grouped_conv_problem.cpp new file mode 100644 index 0000000000..a6a4d8ba08 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_problem.cpp @@ -0,0 +1,245 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvProblem using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_problem.hpp" +#include +#include +#include + +using namespace ck_tile::dispatcher; + +void test_grouped_conv_problem_defaults() +{ + std::cout << " test_grouped_conv_problem_defaults... "; + GroupedConvProblem p; + assert(p.N == 1); + assert(p.C == 64); + assert(p.K == 64); + assert(p.G == 1); + assert(p.Hi() == 28); + assert(p.Wi() == 28); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.op == GroupedConvOp::Forward); + assert(p.stride[0] == 1 && p.stride[1] == 1 && p.stride[2] == 1); + assert(p.padding[0] == 0 && p.padding[1] == 0 && p.padding[2] == 0); + assert(p.dilation[0] == 1 && p.dilation[1] == 1 && p.dilation[2] == 1); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_2d() +{ + std::cout << " test_grouped_conv_problem_2d... "; + GroupedConvProblem p(4, 64, 128, 28, 28, 3, 3); + p.compute_output_size(); + assert(p.N == 4); + assert(p.C == 64); + assert(p.K == 128); + assert(p.Hi() == 28); + assert(p.Wi() == 28); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.Ho() == 26); + assert(p.Wo() == 26); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_strided() +{ + std::cout << " test_grouped_conv_problem_strided... "; + GroupedConvProblem p; + p.N = 1; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 2, 2}; + p.padding = {0, 1, 1}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.Ho() == 7); + assert(p.Wo() == 7); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_grouped() +{ + std::cout << " test_grouped_conv_problem_grouped... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 4; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.G == 4); + assert(p.C % p.G == 0); + assert(p.K % p.G == 0); + assert(p.is_valid()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_depthwise() +{ + std::cout << " test_grouped_conv_problem_depthwise... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 64; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.is_depthwise()); + assert(p.G == p.C && p.G == p.K); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_pointwise() +{ + std::cout << " test_grouped_conv_problem_pointwise... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 128; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 1, 1}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + assert(p.is_pointwise()); + assert(p.Y() == 1 && p.X() == 1); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_flops() +{ + std::cout << " test_grouped_conv_problem_flops... "; + GroupedConvProblem p; + p.N = 2; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.stride = {1, 1, 1}; + p.padding = {0, 0, 0}; + p.dilation = {1, 1, 1}; + p.compute_output_size(); + double flops = p.get_flops(); + assert(flops > 0); + assert(flops == 2.0 * p.N * p.K * p.Ho() * p.Wo() * (p.C / p.G) * p.Y() * p.X()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_is_valid() +{ + std::cout << " test_grouped_conv_problem_is_valid... "; + GroupedConvProblem p; + p.N = 1; + p.C = 64; + p.K = 64; + p.G = 1; + p.input_spatial = {1, 14, 14}; + p.filter_spatial = {1, 3, 3}; + p.compute_output_size(); + assert(p.is_valid()); + + p.N = 0; + assert(!p.is_valid()); + p.N = 1; + + p.C = 0; + assert(!p.is_valid()); + p.C = 64; + + p.K = 0; + assert(!p.is_valid()); + p.K = 64; + + p.G = 0; + assert(!p.is_valid()); + p.G = 1; + + p.C = 64; + p.K = 64; + p.G = 3; + assert(!p.is_valid()); + p.G = 4; + assert(p.is_valid()); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_problem_builder() +{ + std::cout << " test_grouped_conv_problem_builder... "; + auto p = GroupedConvProblemBuilder() + .batch(8) + .channels(128, 256) + .groups(4) + .input_size(32, 32) + .filter_size(3, 3) + .stride(2, 2) + .padding(1, 1) + .dilation(1, 1) + .operation(GroupedConvOp::Forward) + .build(); + assert(p.N == 8); + assert(p.C == 128); + assert(p.K == 256); + assert(p.G == 4); + assert(p.Hi() == 32); + assert(p.Wi() == 32); + assert(p.Y() == 3); + assert(p.X() == 3); + assert(p.stride[1] == 2 && p.stride[2] == 2); + assert(p.padding[1] == 1 && p.padding[2] == 1); + assert(p.op == GroupedConvOp::Forward); + assert(p.is_valid()); + + bool threw = false; + try + { + (void)GroupedConvProblemBuilder() + .batch(0) + .channels(64, 64) + .groups(1) + .input_size(14, 14) + .filter_size(3, 3) + .build(); + } + catch(const std::invalid_argument&) + { + threw = true; + } + assert(threw); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Problem ===\n\n"; + test_grouped_conv_problem_defaults(); + test_grouped_conv_problem_2d(); + test_grouped_conv_problem_strided(); + test_grouped_conv_problem_grouped(); + test_grouped_conv_problem_depthwise(); + test_grouped_conv_problem_pointwise(); + test_grouped_conv_problem_flops(); + test_grouped_conv_problem_is_valid(); + test_grouped_conv_problem_builder(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_registry.cpp b/dispatcher/tests/test_grouped_conv_registry.cpp new file mode 100644 index 0000000000..47d13a9997 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_registry.cpp @@ -0,0 +1,230 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +/// Unit tests for GroupedConvRegistry and GroupedConvDispatcher using assert() and std::cout + +#include "ck_tile/dispatcher/grouped_conv_utils.hpp" +#include +#include +#include +#include + +using namespace ck_tile::dispatcher; +using namespace ck_tile::dispatcher::grouped_conv_decl; + +void test_grouped_conv_registry_basic() +{ + std::cout << " test_grouped_conv_registry_basic... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + reg.set_name("test_registry"); + assert(reg.name() == "test_registry"); + + assert(reg.size() == 0); + assert(reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_register_set() +{ + std::cout << " test_grouped_conv_registry_register_set... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + + bool ok = reg.register_set(set); + assert(ok); + assert(reg.size() == 2); + assert(!reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_all_kernels() +{ + std::cout << " test_grouped_conv_registry_all_kernels... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + auto all = reg.all_kernels(); + assert(all.size() == 1); + assert(all[0]->name().find("grouped_conv_") != std::string::npos); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_clear() +{ + std::cout << " test_grouped_conv_registry_clear... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + assert(reg.size() == 1); + + reg.clear(); + assert(reg.size() == 0); + assert(reg.empty()); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_thread_safe() +{ + std::cout << " test_grouped_conv_registry_thread_safe... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + const int num_threads = 4; + const int sets_per_thread = 10; + std::vector threads; + std::atomic success_count{0}; + + for(int t = 0; t < num_threads; t++) + { + threads.emplace_back([t, ®, &success_count]() { + for(int k = 0; k < sets_per_thread; k++) + { + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128 + t * 32 + k, 128); + if(reg.register_set(set)) + { + success_count++; + } + } + }); + } + + for(auto& th : threads) + th.join(); + + assert(reg.size() == num_threads * sets_per_thread); + assert(success_count.load() == num_threads * sets_per_thread); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_export_json() +{ + std::cout << " test_grouped_conv_registry_export_json... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + std::string json = reg.export_json(false); + assert(!json.empty()); + assert(json.find("\"kernels\"") != std::string::npos); + assert(json.find("\"metadata\"") != std::string::npos); + assert(json.find("grouped_conv_") != std::string::npos); + + std::string json_stats = reg.export_json(true); + assert(json_stats.find("\"statistics\"") != std::string::npos); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_registry_filter() +{ + std::cout << " test_grouped_conv_registry_filter... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + set.add("bf16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + auto fp16_only = + reg.filter([](const GroupedConvKernelInstance& k) { return k.key().dtype_in == "fp16"; }); + assert(fp16_only.size() == 2); + + auto large_tile = reg.filter([](const GroupedConvKernelInstance& k) { + return k.key().tile_m >= 256 || k.key().tile_n >= 256; + }); + assert(large_tile.size() >= 1); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_dispatcher_basic() +{ + std::cout << " test_grouped_conv_dispatcher_basic... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + reg.register_set(set); + + GroupedConvDispatcher dispatcher(®); + GroupedConvProblem problem = grouped_conv_utils::create_grouped_conv2d_problem( + 4, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + float time = dispatcher.run(problem, nullptr); + assert(time >= 0.0f); + + reg.clear(); + std::cout << "PASSED\n"; +} + +void test_grouped_conv_dispatcher_select() +{ + std::cout << " test_grouped_conv_dispatcher_select... "; + GroupedConvRegistry& reg = GroupedConvRegistry::instance(); + reg.clear(); + + GroupedConvKernelSet set; + set.add("fp16", "nhwc", "forward", 128, 128); + set.add("fp16", "nhwc", "forward", 256, 256); + reg.register_set(set); + + GroupedConvDispatcher dispatcher(®); + GroupedConvProblem problem = grouped_conv_utils::create_grouped_conv2d_problem( + 4, 64, 128, 28, 28, 3, 3, 1, 1, GroupedConvOp::Forward); + + const auto* selected = dispatcher.select(problem); + assert(selected != nullptr); + assert(selected->name().find("grouped_conv_") != std::string::npos); + assert(selected->matches(problem)); + + reg.clear(); + std::cout << "PASSED\n"; +} + +int main() +{ + std::cout << "\n=== Test Grouped Conv Registry ===\n\n"; + test_grouped_conv_registry_basic(); + test_grouped_conv_registry_register_set(); + test_grouped_conv_registry_all_kernels(); + test_grouped_conv_registry_clear(); + test_grouped_conv_registry_thread_safe(); + test_grouped_conv_registry_export_json(); + test_grouped_conv_registry_filter(); + test_grouped_conv_dispatcher_basic(); + test_grouped_conv_dispatcher_select(); + std::cout << "\n=== All Tests Passed! ===\n\n"; + return 0; +} diff --git a/dispatcher/tests/test_grouped_conv_utils.py b/dispatcher/tests/test_grouped_conv_utils.py new file mode 100644 index 0000000000..9d0638dc08 --- /dev/null +++ b/dispatcher/tests/test_grouped_conv_utils.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 + +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +""" +TDD tests for python/grouped_conv_utils.py -- grouped convolution Python utilities. + +Phase 1 TDD: tests written BEFORE implementation exists. +Run: python3 -m pytest tests/test_grouped_conv_utils.py -v +""" + +import sys +import unittest +from pathlib import Path + +SCRIPT_DIR = Path(__file__).parent.resolve() +DISPATCHER_DIR = SCRIPT_DIR.parent +sys.path.insert(0, str(DISPATCHER_DIR / "python")) +sys.path.insert(0, str(DISPATCHER_DIR / "codegen")) + +from dispatcher_common import ValidationResultBase # noqa: E402 +from grouped_conv_utils import ( # noqa: E402 + GroupedConvValidationResult, + validate_grouped_conv_config, + auto_correct_grouped_conv_config, + get_grouped_conv_default_config, + GroupedConvDataType, + format_grouped_conv_summary, +) + + +# ============================================================================= +# VALID CONFIG FIXTURES +# ============================================================================= + + +def make_valid_grouped_conv_config(): + """Return a valid grouped conv config dict for gfx942.""" + return { + "tile_config": { + "tile_k": 128, + "tile_c": 128, + "wave_m": 2, + "wave_n": 2, + "wave_k": 1, + "warp_m": 32, + "warp_n": 32, + "warp_k": 16, + }, + "trait_config": { + "pipeline": "compv4", + "epilogue": "cshuffle", + "scheduler": "intrawave", + }, + "variant": "2d_fwd", + "ndim_spatial": 2, + "arch": "gfx942", + "layout": "nhwgc", + "dtype": "fp16", + } + + +# ============================================================================= +# TestGroupedConvValidationResult +# ============================================================================= + + +class TestGroupedConvValidationResult(unittest.TestCase): + """Tests for GroupedConvValidationResult dataclass.""" + + def test_inherits_from_validation_result_base(self): + """GroupedConvValidationResult should inherit from ValidationResultBase.""" + self.assertTrue( + issubclass(GroupedConvValidationResult, ValidationResultBase), + "GroupedConvValidationResult must inherit from ValidationResultBase", + ) + + def test_valid_result_has_is_valid(self): + """Valid result has is_valid=True.""" + vr = GroupedConvValidationResult(is_valid=True) + self.assertTrue(vr.is_valid) + + def test_invalid_result_has_is_valid_false(self): + """Invalid result has is_valid=False.""" + vr = GroupedConvValidationResult(is_valid=False, errors=["bad config"]) + self.assertFalse(vr.is_valid) + + def test_has_errors_list(self): + """Result has errors list.""" + vr = GroupedConvValidationResult( + is_valid=False, + errors=["invalid wave", "invalid trait"], + ) + self.assertEqual(len(vr.errors), 2) + self.assertIn("invalid wave", vr.errors) + self.assertIn("invalid trait", vr.errors) + + def test_has_warnings_list(self): + """Result has warnings list.""" + vr = GroupedConvValidationResult( + is_valid=True, + warnings=["deprecated option"], + ) + self.assertEqual(len(vr.warnings), 1) + self.assertIn("deprecated option", vr.warnings) + + def test_has_suggested_fixes_dict(self): + """Result has suggested_fixes dict.""" + vr = GroupedConvValidationResult( + is_valid=False, + suggested_fixes={"wave_m": 2, "wave_n": 2}, + ) + self.assertIn("wave_m", vr.suggested_fixes) + self.assertEqual(vr.suggested_fixes["wave_m"], 2) + self.assertIn("wave_n", vr.suggested_fixes) + self.assertEqual(vr.suggested_fixes["wave_n"], 2) + + def test_default_empty_errors_warnings_fixes(self): + """Default result has empty errors, warnings, suggested_fixes.""" + vr = GroupedConvValidationResult(is_valid=True) + self.assertEqual(vr.errors, []) + self.assertEqual(vr.warnings, []) + self.assertEqual(vr.suggested_fixes, {}) + + +# ============================================================================= +# TestValidateGroupedConvConfig +# ============================================================================= + + +class TestValidateGroupedConvConfig(unittest.TestCase): + """Tests for validate_grouped_conv_config.""" + + def test_valid_config_passes(self): + """Valid config should pass validation.""" + config = make_valid_grouped_conv_config() + result = validate_grouped_conv_config(config) + self.assertTrue(result.is_valid, f"Expected valid, got errors: {result.errors}") + self.assertEqual(result.errors, []) + + def test_invalid_wave_config_fails(self): + """Invalid wave config should fail validation.""" + config = make_valid_grouped_conv_config() + config["tile_config"]["wave_m"] = 3 + config["tile_config"]["wave_n"] = 3 + result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + error_str = " ".join(result.errors).lower() + self.assertIn("wave", error_str) + + def test_invalid_trait_fails(self): + """Invalid trait combination should fail validation.""" + config = make_valid_grouped_conv_config() + config["trait_config"]["pipeline"] = "compv4" + config["trait_config"]["epilogue"] = "cshuffle" + config["trait_config"]["scheduler"] = "interwave" # Invalid combo + result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + error_str = " ".join(result.errors).lower() + self.assertIn("trait", error_str) + + def test_missing_fields_fails(self): + """Config with missing required fields should fail validation.""" + config = {"arch": "gfx942"} # Missing tile_config, trait_config, etc. + result = validate_grouped_conv_config(config) + self.assertFalse(result.is_valid) + self.assertGreater(len(result.errors), 0) + + +# ============================================================================= +# TestAutoCorrectGroupedConvConfig +# ============================================================================= + + +class TestAutoCorrectGroupedConvConfig(unittest.TestCase): + """Tests for auto_correct_grouped_conv_config.""" + + def test_invalid_wave_gets_corrected(self): + """Invalid wave config should be auto-corrected.""" + config = make_valid_grouped_conv_config() + config["tile_config"]["wave_m"] = 3 + config["tile_config"]["wave_n"] = 3 + corrected, result = auto_correct_grouped_conv_config(config) + self.assertIsInstance(corrected, dict) + self.assertIsInstance(result, GroupedConvValidationResult) + # Corrected wave should be valid for arch + wave_m = corrected.get("tile_config", {}).get("wave_m") + wave_n = corrected.get("tile_config", {}).get("wave_n") + self.assertIn(wave_m, [1, 2, 4]) + self.assertIn(wave_n, [1, 2, 4]) + + def test_invalid_trait_gets_corrected(self): + """Invalid trait combination should be auto-corrected.""" + config = make_valid_grouped_conv_config() + config["trait_config"]["scheduler"] = "interwave" + config["trait_config"]["pipeline"] = "compv4" + config["trait_config"]["epilogue"] = "cshuffle" + corrected, result = auto_correct_grouped_conv_config(config) + self.assertIsInstance(corrected, dict) + self.assertIsInstance(result, GroupedConvValidationResult) + # Scheduler should be corrected to intrawave for compv4+cshuffle + scheduler = corrected.get("trait_config", {}).get("scheduler") + self.assertEqual(scheduler, "intrawave") + + +# ============================================================================= +# TestGetGroupedConvDefaultConfig +# ============================================================================= + + +class TestGetGroupedConvDefaultConfig(unittest.TestCase): + """Tests for get_grouped_conv_default_config.""" + + def test_returns_config(self): + """Should return a GroupedConvKernelConfig (or dict via to_dict).""" + config = get_grouped_conv_default_config("2d_fwd") + # Accepts both dataclass and dict + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIsInstance(d, dict) + + def test_has_tile_config(self): + """Returned config has tile_config key.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("tile_config", d) + self.assertIsInstance(d["tile_config"], dict) + + def test_has_trait_config(self): + """Returned config has trait_config key.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("trait_config", d) + self.assertIsInstance(d["trait_config"], dict) + + def test_has_variant(self): + """Returned config has variant.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("variant", d) + + def test_has_ndim_spatial(self): + """Returned config has ndim_spatial.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("ndim_spatial", d) + + def test_has_arch(self): + """Returned config has arch.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("arch", d) + + def test_has_layout(self): + """Returned config has layout.""" + config = get_grouped_conv_default_config("2d_fwd") + d = config.to_dict() if hasattr(config, "to_dict") else config + self.assertIn("layout", d) + + +# ============================================================================= +# TestGroupedConvDataType +# ============================================================================= + + +class TestGroupedConvDataType(unittest.TestCase): + """Tests for GroupedConvDataType enum.""" + + def test_fp16_exists(self): + """GroupedConvDataType has FP16.""" + self.assertIsNotNone(GroupedConvDataType.FP16) + + def test_bf16_exists(self): + """GroupedConvDataType has BF16.""" + self.assertIsNotNone(GroupedConvDataType.BF16) + + def test_fp32_exists(self): + """GroupedConvDataType has FP32.""" + self.assertIsNotNone(GroupedConvDataType.FP32) + + def test_fp8_exists(self): + """GroupedConvDataType has FP8.""" + self.assertIsNotNone(GroupedConvDataType.FP8) + + def test_bf8_exists(self): + """GroupedConvDataType has BF8.""" + self.assertIsNotNone(GroupedConvDataType.BF8) + + def test_int8_exists(self): + """GroupedConvDataType has INT8.""" + self.assertIsNotNone(GroupedConvDataType.INT8) + + def test_enum_values_unique(self): + """All enum values should be unique.""" + values = [ + GroupedConvDataType.FP16, + GroupedConvDataType.BF16, + GroupedConvDataType.FP32, + GroupedConvDataType.FP8, + GroupedConvDataType.BF8, + GroupedConvDataType.INT8, + ] + self.assertEqual(len(values), len(set(values))) + + +# ============================================================================= +# TestFormatGroupedConvSummary +# ============================================================================= + + +class TestFormatGroupedConvSummary(unittest.TestCase): + """Tests for format_grouped_conv_summary.""" + + def test_returns_non_empty_string(self): + """Should return a non-empty string.""" + config = make_valid_grouped_conv_config() + summary = format_grouped_conv_summary(config) + self.assertIsInstance(summary, str) + self.assertGreater(len(summary), 0) + + def test_contains_key_info(self): + """Summary should contain key config info (variant, arch, layout, dtype).""" + config = make_valid_grouped_conv_config() + summary = format_grouped_conv_summary(config) + # Should mention at least some of: variant, arch, layout, dtype + summary_lower = summary.lower() + has_key_info = ( + "2d" in summary_lower + or "fwd" in summary_lower + or "gfx" in summary_lower + or "nhwgc" in summary_lower + or "fp16" in summary_lower + ) + self.assertTrue( + has_key_info, + f"Summary should contain key info, got: {summary}", + ) + + def test_empty_config_returns_something(self): + """Empty or minimal config should still return a string.""" + summary = format_grouped_conv_summary({}) + self.assertIsInstance(summary, str) + self.assertGreaterEqual(len(summary), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/dispatcher/tests/test_problem_extended.cpp b/dispatcher/tests/test_problem_extended.cpp index 21ea545292..ba6068e3ee 100644 --- a/dispatcher/tests/test_problem_extended.cpp +++ b/dispatcher/tests/test_problem_extended.cpp @@ -19,7 +19,7 @@ class ProblemDimensionInferenceTest : public ::testing::Test TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) { - // A: M×K (1024×512), B: K×N (512×2048) + // A: MxK (1024x512), B: KxN (512x2048) auto problem = Problem::from_ab(1024, 512, 512, 2048); EXPECT_EQ(problem.M, 1024); @@ -30,7 +30,7 @@ TEST_F(ProblemDimensionInferenceTest, FromAB_Basic) TEST_F(ProblemDimensionInferenceTest, FromDimensions_Valid) { - // A: 1024×512, B: 512×2048, C: 1024×2048 + // A: 1024x512, B: 512x2048, C: 1024x2048 auto problem = Problem::from_dimensions(1024, 512, 512, 2048, 1024, 2048); EXPECT_EQ(problem.M, 1024); @@ -55,7 +55,7 @@ TEST_F(ProblemDimensionInferenceTest, FromShapes_WithC) TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) { - // A stored as K×M (transposed) + // A stored as KxM (transposed) TensorShape A{512, 1024, true}; TensorShape B{512, 2048, false}; TensorShape C{1024, 2048, false}; @@ -70,7 +70,7 @@ TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedA) TEST_F(ProblemDimensionInferenceTest, FromShapes_TransposedB) { TensorShape A{1024, 512, false}; - // B stored as N×K (transposed) + // B stored as NxK (transposed) TensorShape B{2048, 512, true}; TensorShape C{1024, 2048, false}; diff --git a/dispatcher/tests/test_real_kernel_multi_size.cpp b/dispatcher/tests/test_real_kernel_multi_size.cpp index f23f684631..79282da557 100644 --- a/dispatcher/tests/test_real_kernel_multi_size.cpp +++ b/dispatcher/tests/test_real_kernel_multi_size.cpp @@ -187,7 +187,7 @@ int main() for(const auto& r : results) { char size_str[32]; - snprintf(size_str, sizeof(size_str), "%4d×%4d×%4d", r.M, r.N, r.K); + snprintf(size_str, sizeof(size_str), "%4dx%4dx%4d", r.M, r.N, r.K); printf(" %-14s | %9.4f | %6.2f | %7.2f%% | %s\n", size_str, diff --git a/dispatcher/tests/test_real_kernel_performance.cpp b/dispatcher/tests/test_real_kernel_performance.cpp index ff3d635968..29c7c80ac3 100644 --- a/dispatcher/tests/test_real_kernel_performance.cpp +++ b/dispatcher/tests/test_real_kernel_performance.cpp @@ -144,7 +144,7 @@ int main() all_passed = all_passed && passed; char size_label[32]; - snprintf(size_label, sizeof(size_label), "%s %d³", label, M); + snprintf(size_label, sizeof(size_label), "%s %d^3", label, M); printf(" %-9s | %9.4f | %6.2f | %9.1f | %s\n", size_label,