From ceddfcc13cd32f9431b96cdbbff15906b05f02cc Mon Sep 17 00:00:00 2001
From: Copilot <198982749+Copilot@users.noreply.github.com>
Date: Wed, 8 Apr 2026 14:58:53 +0800
Subject: [PATCH 01/34] [CK_TILE] Refine FMHA Readme (#6003)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Updates the FMHA README to document fp8 precision support more
accurately, replacing the outdated "experimental" section and incomplete
CLI arg descriptions.
## Changes
- **`-prec` arg**: expanded supported values from `fp16/bf16/fp8/bf8` →
`fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4`
- **`-qscale` arg**: replaced single-line `1: per-tensor quantization`
with all four modes: `pt/1`, `bs/2`, `kvbs/3`, `mx/4`
- **FP8 support section**: replaced "FP8 experimental support" paragraph
with:
- Supported targets: gfx942/gfx950 + ROCm 6.0+
- Table distinguishing `fp8` / `fp8bf16` / `fp8fp32` by Q/K/V input type
and output type
- Table for all `-qscale` modes with descriptions
- Note that `-vlayout=r` (`seqlen*hdim` for V) is the only supported
layout for fp8 types
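
The alias/number pairs documented above can be illustrated with a small lookup (a hypothetical sketch for illustration only; the real argument parsing lives inside `tile_example_fmha_fwd`, not in this snippet):

```python
# Hypothetical helper mirroring the documented -qscale aliases.
QSCALE_MODES = {
    "n": 0, "0": 0,      # no quantization scale (default)
    "pt": 1, "1": 1,     # per-tensor scale
    "bs": 2, "2": 2,     # block scale
    "kvbs": 3, "3": 3,   # Q per-tensor, K/V per-page block scale
    "mx": 4, "4": 4,     # microscaling, exclusively for mxfp8/mxfp4
}

def parse_qscale(value: str) -> int:
    """Map a -qscale CLI value (alias or digit) to its numeric mode."""
    try:
        return QSCALE_MODES[value]
    except KeyError:
        raise ValueError(f"unsupported -qscale value: {value}")

print(parse_qscale("pt"))    # 1
print(parse_qscale("kvbs"))  # 3
```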
Original prompt
Please open a PR against base branch `develop` in repository
`ROCm/rocm-libraries` applying the following documentation updates
within the composable kernel path.
## Scope
Update the file:
- `projects/composablekernel/example/ck_tile/01_fmha/README.md`
## Changes to apply
Apply the combined edits described in the diffs below (two consecutive
patches). Ensure the final file content includes **both** sets of
changes.
### Patch 1
- In the CLI args section:
- Update `-qscale` description lines to include:
- `pt or 1, per-tensor scale`
- `bs or 2, block scale`
- `kvbs or 3, Q per-tensor, K/V per-page block scale`
- `mx or 4, microscaling (exclusively for mxfp8/mxfp4)`
- Update `-prec` supported data types from `fp16/bf16/fp8/bf8` to
`fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4`.
- Replace the existing "FP8 experimental support" section with an "FP8
support" section stating:
- FP8 FMHA kernels supported on gfx942/gfx950 with ROCm 6.0+
- Precision selectable via `-prec=fp8` (or `fp8bf16`, `fp8fp32`) for
`tile_example_fmha_fwd`
- Add a table describing `-qscale` modes:
- `n` or `0`: No quantization scale (default)
- `pt` or `1`: Per-tensor quantization scale
- `bs` or `2`: Per-block quantization scale
- `kvbs` or `3`: Q per-tensor + K/V per-page block scale
- `mx` or `4`: Microscaling (MX format), exclusively for `mxfp8` and
`mxfp4`
- Add/keep note that currently only `-vlayout=r` (`seqlen*hdim` for V
matrix) is supported for fp8 data types.
### Patch 2
Further refine the "FP8 support" paragraph to explain the difference
between `fp8`, `fp8bf16`, and `fp8fp32` via a table:
| `-prec` value | Q/K/V input type | Output type | Description |
|---|---|---|---|
| `fp8` | fp8 | fp8 | Fully fp8: both inputs and output are in fp8 |
| `fp8bf16` | fp8 | bf16 | Mixed precision: fp8 inputs, bf16 output —
useful when the consumer expects a wider-range output format |
| `fp8fp32` | fp8 | fp32 | Mixed precision: fp8 inputs, fp32 output —
highest-precision output, suitable for debugging or further fp32
processing |
Keep the rest of the `-qscale` table and the `-vlayout=r` limitation
note.
## Notes
- PR title must be: `[CK_TILE] Add fp8 in FMHA readme`
- Ensure markdown formatting renders correctly (tables, code
formatting).
- Only modify the file listed above.
The following is the prior conversation context from the user's chat
exploration (may be truncated):
User: Can you help me submit this PR? It goes under the composable kernel path:
diff --git a/example/ck_tile/01_fmha/README.md
b/example/ck_tile/01_fmha/README.md
index 0b526f4e9fc..1627435863b 100644
--- a/example/ck_tile/01_fmha/README.md
+++ b/example/ck_tile/01_fmha/README.md
@@ -62,14 +62,17 @@ args:
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
-qscale n or 0, no scaling (default:n)
- 1: per-tensor quantization.
+ pt or 1, per-tensor scale
+ bs or 2, block scale
+ kvbs or 3, Q per-tensor, K/V per-page block scale
+ mx or 4, microscaling (exclusively for mxfp8/mxfp4)
-iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1)
-bias n or 0, no bias (default:n)
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2,
b*h*s*s
a(libi) or 2, alibi with 1*h. a:1, b*h
- -prec data type. fp16/bf16/fp8/bf8 (default:fp16)
+ -prec data type. fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4
(default:fp16)
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')
(default:0)
't', top-left causal mask, 'b', bottom-r causal mask
't:l,r', top-left sliding window attn(swa) with FA style left right size
@@ -161,7 +164,17 @@ We support sequence padding and variable-length
processing in both batch and gro
Both approaches optimize memory access patterns while supporting
flexible sequence length requirements commonly found in transformer
inference scenarios.
-## FP8 experimental support
-As described in [this
blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482),
we have an experimental support for fp8 fmha kernels, you can evaluate
the performance by setting the arg `-prec=fp8` to the
`tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
+## FP8 support
+FP8 FMHA kernels are supported on gfx942/gfx950 machines with ROCm
6.0+. You can select fp8 precision by setting the arg `-prec=fp8` (or
`fp8bf16`, `fp8fp32`) to the `tile_example_fmha_fwd`.
-Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for
fp8 and fp8bf16 now. Full feature support will come later.
+The following quantization scale modes are available via `-qscale`:
+
+| `-qscale` value | Description |
+|---|---|
+| `n` or `0` | No quantization sca...
*This pull request was created from Copilot chat.*
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: asleepzzz <4926646+asleepzzz@users.noreply.github.com>
Co-authored-by: asleepzzz
---
example/ck_tile/01_fmha/README.md | 29 ++++++++++++++++++++++++-----
1 file changed, 24 insertions(+), 5 deletions(-)
diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md
index 0b526f4e9f..2aaaa45a9a 100644
--- a/example/ck_tile/01_fmha/README.md
+++ b/example/ck_tile/01_fmha/README.md
@@ -62,14 +62,17 @@ args:
-d_v head dim for v, -1 means equal to d (default:-1)
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
-qscale n or 0, no scaling (default:n)
- 1: per-tensor quantization.
+ pt or 1, per-tensor scale
+ bs or 2, block scale
+ kvbs or 3, Q per-tensor, K/V per-page block scale, only in batch_prefill
+ mx or 4, microscaling (exclusively for mxfp8/mxfp4)
-iperm permute input (default:1)
if true, will be b*h*s*d, else b*s*h*d
-operm permute output (default:1)
-bias n or 0, no bias (default:n)
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
a(libi) or 2, alibi with 1*h. a:1, b*h
- -prec data type. fp16/bf16/fp8/bf8 (default:fp16)
+ -prec data type. fp32/fp16/bf16/fp8/fp8bf16/fp8fp32/mxfp8/mxfp4 (default:fp16)
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
't', top-left causal mask, 'b', bottom-r causal mask
't:l,r', top-left sliding window attn(swa) with FA style left right size
@@ -161,7 +164,23 @@ We support sequence padding and variable-length processing in both batch and gro
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
-## FP8 experimental support
-As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
+## FP8 support
+FP8 FMHA kernels are supported on gfx942/gfx950 machines with ROCm 6.0+. Three fp8-based precision modes are available via `-prec`:
-Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later.
+| `-prec` value | Q/K/V input type | Output type | Description |
+|---|---|---|---|
+| `fp8` | fp8 | fp8 | Fully fp8: both inputs and output are in fp8 |
+| `fp8bf16` | fp8 | bf16 | Mixed precision: fp8 inputs, bf16 output — useful when the consumer expects a wider-range output format |
+| `fp8fp32` | fp8 | fp32 | Mixed precision: fp8 inputs, fp32 output — highest-precision output, suitable for debugging or further fp32 processing |
+
+The following quantization scale modes are available via `-qscale`:
+
+| `-qscale` value | Description |
+|---|---|
+| `n` or `0` | No quantization scale (default) |
+| `pt` or `1` | Per-tensor quantization scale — a single scale factor is applied to the entire tensor |
+| `bs` or `2` | Per-block quantization scale — a scale factor is applied per block of elements |
+| `kvbs` or `3` | Q per-tensor + K/V per-page block scale (batch_prefill only) |
+| `mx` or `4` | Microscaling (MX format), exclusively for `mxfp8` and `mxfp4` data types |
+
+Currently only `-vlayout=r` (`seqlen*hdim` for V matrix) is supported for fp8 data types.
From c9539824349869bb11d78915036adbb56ea1a18f Mon Sep 17 00:00:00 2001
From: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com>
Date: Wed, 8 Apr 2026 02:54:56 -0700
Subject: [PATCH 02/34] [CK] [CK Tile] Improved ci_safety_check in smart-build
infrastructure (#6215)
## Motivation
The two-dot syntax (`origin/develop..HEAD`) is more conservative and
catches a broader set of changes when a PR merges the develop branch.
While the three-dot syntax shows only PR-specific changes, two-dot
ensures we don't miss any files that differ between develop and the PR
branch, including files modified both in the PR and in merged develop
commits.
This conservative approach prioritizes catching all potential issues
over CI efficiency, which is appropriate for build system change
detection.
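
The merge-commit scenario above can be reproduced in a throwaway repository. The sketch below (illustrative branch and file names, assuming `git` is on PATH) shows why two-dot catches a file that three-dot misses once develop advances past the PR's last merge of develop:

```python
import subprocess, tempfile, os

def git(*args, cwd):
    # Run a git command and return its stdout.
    return subprocess.run(["git", *args], cwd=cwd, check=True,
                          capture_output=True, text=True).stdout

repo = tempfile.mkdtemp()
git("init", "-q", "-b", "develop", cwd=repo)
git("config", "user.email", "ci@example.com", cwd=repo)
git("config", "user.name", "ci", cwd=repo)

def commit(name, msg):
    with open(os.path.join(repo, name), "w") as f:
        f.write(name)
    git("add", ".", cwd=repo)
    git("commit", "-qm", msg, cwd=repo)

commit("base.txt", "base")
git("checkout", "-qb", "feature", cwd=repo)
commit("feature.txt", "pr change")
git("checkout", "-q", "develop", cwd=repo)
commit("Jenkinsfile", "infra change on develop")
git("checkout", "-q", "feature", cwd=repo)
git("merge", "-q", "--no-edit", "develop", cwd=repo)  # PR merges develop
git("checkout", "-q", "develop", cwd=repo)
commit("CMakePresets.json", "develop moves on after the merge")
git("checkout", "-q", "feature", cwd=repo)

three_dot = git("diff", "--name-only", "develop...feature", cwd=repo).split()
two_dot   = git("diff", "--name-only", "develop..feature",  cwd=repo).split()
print("three-dot:", three_dot)  # only the PR's own file
print("two-dot:  ", two_dot)    # also files develop changed after the merge
```

Here three-dot reports only `feature.txt`, while two-dot additionally surfaces `CMakePresets.json`, which changed on develop after the PR's merge commit.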
## Technical Details
- Switched to two-dot (..) syntax in ci_safety_check.sh
- Update comments to clarify the intentional use of two-dot syntax
- Maintain consistency across both CHANGE_ID branches
- Trigger a full build when any of the following change:
- `Dockerfile|Jenkinsfile|CMakePresets\.json|script/dependency-parser/`
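
The trigger check can be condensed as follows (the pattern fragment here is abridged from `ci_safety_check.sh`; the script's actual pattern covers more file classes):

```python
import re

# Abridged build/infrastructure pattern: any match forces a full build.
BUILD_INFRA_PATTERN = re.compile(
    r"(CMakeLists\.txt|\.cmake$|CMakePresets\.json"
    r"|Dockerfile|Jenkinsfile|script/)"
)

def needs_full_build(changed_files):
    """Return True if any changed file touches build infrastructure."""
    return any(BUILD_INFRA_PATTERN.search(f) for f in changed_files)

print(needs_full_build(["src/gemm.cpp"]))                                 # False
print(needs_full_build(["script/dependency-parser/ci_safety_check.sh"]))  # True
```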
## Test Plan
Tested with PR 6200 which has multiple merge-commits.
## Test Result
It detects 43 new tests compared to the three-dot scheme.
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------
Co-authored-by: andrew clark
---
script/dependency-parser/ci_safety_check.sh | 36 +++++++++++++--------
script/dependency-parser/validate_pr.sh | 2 +-
2 files changed, 24 insertions(+), 14 deletions(-)
diff --git a/script/dependency-parser/ci_safety_check.sh b/script/dependency-parser/ci_safety_check.sh
index adfe7d7f09..bd19a0630f 100755
--- a/script/dependency-parser/ci_safety_check.sh
+++ b/script/dependency-parser/ci_safety_check.sh
@@ -18,8 +18,8 @@
# CHANGE_TARGET - Base branch for PR builds (set by Jenkins Multibranch Pipeline)
#
# Note: CHANGE_ID may not be set even for PR builds if Jenkins job is not
-# configured as Multibranch Pipeline. Script uses three-dot git diff syntax
-# to correctly detect PR changes regardless of CHANGE_ID availability.
+# configured as Multibranch Pipeline. Script uses two-dot git diff syntax
+# to detect PR changes regardless of CHANGE_ID availability.
#
# Manual override (set by developer/admin if needed):
# DISABLE_SMART_BUILD - Set to "true" to force full build
@@ -48,19 +48,29 @@ fi
# 3. Force full build if CMakeLists.txt or cmake/ configuration changed
# Always compare against base branch (not consecutive commits) to avoid false positives from merge commits
-# Three-dot syntax (...) only shows changes actually made in the PR, not changes from merged develop branch
-if [ -n "$CHANGE_ID" ]; then
- # This is a PR build (CHANGE_ID set by Jenkins Multibranch Pipeline)
- CHANGED_FILES=$(git diff --name-only origin/${BASE_BRANCH}...HEAD 2>/dev/null || echo "")
-else
- # Fallback: Works for both branch builds and PRs without CHANGE_ID
- # Use three-dot syntax to avoid including merge commit changes from develop
- CHANGED_FILES=$(git diff --name-only origin/${BASE_BRANCH}...HEAD 2>/dev/null || echo "")
-fi
+# Two-dot syntax (..) compares current state against base branch
+# Note: This includes merged changes from develop, which is conservative but safe (catches all potentially affected files)
+CHANGED_FILES=$(git diff --name-only origin/${BASE_BRANCH}..HEAD 2>/dev/null || echo "")
-if echo "$CHANGED_FILES" | grep -qE "(CMakeLists\.txt|cmake/.*\.cmake)"; then
+# Comprehensive pattern for build/infrastructure files that require full build:
+# - CMake: CMakeLists.txt, *.cmake, *.cmake.in, CMakePresets.json
+# - Docker: Dockerfile*, docker-compose*
+# - CI/CD: Jenkinsfile, .github/, .gitlab-ci.yml, .pre-commit-config.yaml, .readthedocs.yaml
+# - Scripts: script/ directory (cmake, dependency-parser, build utilities)
+# - Compiler: .clang-format, .clang-tidy
+# - Python: setup.py, pyproject.toml, requirements*.txt
+BUILD_INFRA_PATTERN="(CMakeLists\.txt"
+BUILD_INFRA_PATTERN="${BUILD_INFRA_PATTERN}|\.cmake$|\.cmake\.in$|CMakePresets\.json"
+BUILD_INFRA_PATTERN="${BUILD_INFRA_PATTERN}|Dockerfile|docker-compose"
+BUILD_INFRA_PATTERN="${BUILD_INFRA_PATTERN}|Jenkinsfile|\.github/|\.gitlab-ci\.yml"
+BUILD_INFRA_PATTERN="${BUILD_INFRA_PATTERN}|\.pre-commit-config\.yaml|\.readthedocs\.yaml"
+BUILD_INFRA_PATTERN="${BUILD_INFRA_PATTERN}|script/"
+BUILD_INFRA_PATTERN="${BUILD_INFRA_PATTERN}|\.clang-format|\.clang-tidy"
+BUILD_INFRA_PATTERN="${BUILD_INFRA_PATTERN}|setup\.py|pyproject\.toml|requirements.*\.txt)"
+
+if echo "$CHANGED_FILES" | grep -qE "${BUILD_INFRA_PATTERN}"; then
FORCE_FULL_BUILD=true
- REASON="build system configuration changed (CMakeLists.txt or cmake/*.cmake)"
+ REASON="build system configuration changed"
fi
# 4. Force full build if dependency cache is older than 7 days
diff --git a/script/dependency-parser/validate_pr.sh b/script/dependency-parser/validate_pr.sh
index 61f185af8d..f8c77a2811 100755
--- a/script/dependency-parser/validate_pr.sh
+++ b/script/dependency-parser/validate_pr.sh
@@ -189,7 +189,7 @@ git log --oneline -5
log_section "Step 3: Analyze Changed Files"
log_info "Files changed vs $BASE_BRANCH:"
-CHANGED_FILES=$(git diff --name-only ${BASE_BRANCH}...HEAD -- projects/composablekernel)
+CHANGED_FILES=$(git diff --name-only ${BASE_BRANCH}..HEAD -- projects/composablekernel)
NUM_FILES=$(echo "$CHANGED_FILES" | wc -l)
echo "$CHANGED_FILES" | head -20
if [ "$NUM_FILES" -gt 20 ]; then
From 65ad35becde33434565aa0533e4f89249599b4cf Mon Sep 17 00:00:00 2001
From: Hosang Yoon <156028780+hyoon1@users.noreply.github.com>
Date: Wed, 8 Apr 2026 10:51:53 -0400
Subject: [PATCH 03/34] [CK_TILE] Optimize FMHA head-dim padded path on
gfx11/gfx12 (#6156)
## Motivation
On gfx11/gfx12, FMHA forward kernels that require head-dim padding show
a large performance drop compared to the exact-head-dim path. In
practice, padded cases such as `HDIM=72` and `HDIM=80` were falling too
far off the fast path.
This PR improves padded-head-dim FMHA performance on gfx11/gfx12 while
keeping the behavior for other GPUs unchanged.
## Technical Details
- Add/scope a dedicated padded-head-dim (`qr_hpad`) FMHA forward path
for gfx11/gfx12.
- For `receipt=0`, keep support conservative and only enable the padded
fast path for vector-safe cases (`head_dim % 8 == 0`), matching the
existing assumption used on other GPUs.
- Move `v_prefetch` later only for the head-dim-padded path on
gfx11/gfx12. This reduces live ranges and removes the register-spill
behavior seen in the earlier scheduling.
- Enable the buffer-load OOB check offset trick for the padded path on
gfx11/gfx12.
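
The scheduling choice can be summarized as a tiny selection function (a Python paraphrase for illustration; in the patch itself this is an `if constexpr` over the `__gfx11__`/`__gfx12__` macros and `kPadHeadDimV`):

```python
from enum import Enum

class VPrefetchPoint(Enum):
    BEFORE_GEMM0_TAIL = 0  # original scheduling, kept for other GPUs
    AFTER_GEMM0_TAIL = 1
    AFTER_SOFTMAX = 2      # shortest V live range; used on the padded path

def select_v_prefetch(is_gfx11_or_12: bool, pad_head_dim_v: bool) -> VPrefetchPoint:
    """Pick where the V tile prefetch is issued, mirroring the patch's logic."""
    if not is_gfx11_or_12:
        return VPrefetchPoint.BEFORE_GEMM0_TAIL
    return (VPrefetchPoint.AFTER_SOFTMAX if pad_head_dim_v
            else VPrefetchPoint.AFTER_GEMM0_TAIL)
```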
## Test Plan
./build/bin/tile_example_fmha_fwd -prec=bf16 -mode={0/1} -b=1 -h=16
-d={72/80} -s={seqlen} -s_k={seqlen} -lse=0 -iperm={0/1} -operm={0/1}
## Test Result
Observed padded-head-dim performance improvements for HDIM=72/80:
- gfx11: ~3.5x
- gfx1151: ~2.0x
- gfx12: ~1.3x
## Submission Checklist
- [ ] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---
.../ck_tile/01_fmha/codegen/cpp_symbol_map.py | 2 +
.../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 83 +++++++++++++++++--
.../pipeline/block_fmha_pipeline_enum.hpp | 7 ++
.../pipeline/block_fmha_pipeline_qr_ks_vs.hpp | 78 ++++++++++++-----
4 files changed, 144 insertions(+), 26 deletions(-)
diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
index e9ae11fb5f..79fe6492a6 100644
--- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
+++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py
@@ -139,6 +139,7 @@ LAYOUT_MAP = {"row": "true", "col": "false"}
PIPELINE_MAP = {
"qr": "ck_tile::BlockFmhaPipelineQRKSVS",
+ "qr_hpad": "ck_tile::BlockFmhaPipelineQRKSVSHpad",
"qr_async": "ck_tile::BlockFmhaPipelineQRKSVSAsync",
"qs": "ck_tile::BlockFmhaPipelineQSKSVS",
"qr_async_trload": "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
@@ -147,6 +148,7 @@ PIPELINE_MAP = {
PIPELINE_ENUM_MAP = {
"qr": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
+ "qr_hpad": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_HPAD",
"qr_async": "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
"qr_nwarp_sshuffle": "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
"qs": "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
index 42e2d1f487..c64a19104e 100644
--- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
+++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py
@@ -60,6 +60,22 @@ FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
#include "fmha_fwd.hpp"
"""
+FMHA_FWD_KERNEL_HEADER_QR_HPAD = """// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.\n
+// auto generated by generate.py
+#if defined(__HIP_DEVICE_COMPILE__) && \
+ (defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \
+ defined(__gfx1103__) || defined(__gfx1150__) || defined(__gfx1151__) || \
+ defined(__gfx1152__) || defined(__gfx1153__) || defined(__gfx11_generic__) || \
+ defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__))
+#if !defined(CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK)
+#define CK_TILE_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK 1
+#endif
+#endif
+#include "ck_tile/ops/fmha/block/variants.hpp"
+#include "fmha_fwd.hpp"
+"""
+
FMHA_FWD_KERNEL_BODY_TEMPLATE = """
#include
@@ -300,7 +316,7 @@ class FmhaFwdApiTrait:
return "true" # always support
else:
return "true"
- elif self.pipeline_tag in ["qr", "qs"]:
+ elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
if self.spad == "t":
return f"true /*a.seqlen_q % {self.bm0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
@@ -323,7 +339,7 @@ class FmhaFwdApiTrait:
return f"(a.cu_seqlen_k_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)"
else:
return f"(a.cu_seqlen_k_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)"
- elif self.pipeline_tag in ["qr", "qs"]:
+ elif self.pipeline_tag in ["qr", "qr_hpad", "qs"]:
if self.skpad == "t":
return f"true /*a.seqlen_k % {self.bn0} != 0*/" # TODO: order of get_pipelines() matters! (ugly)
else:
@@ -344,6 +360,11 @@ class FmhaFwdApiTrait:
return f"a.hdim_q % {vec} == 0"
else:
assert False
+ elif self.pipeline_tag == "qr_hpad":
+ if self.dpad == "t":
+ return "a.hdim_q % 8 == 0"
+ else:
+ assert False
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dpad == "t":
@@ -361,6 +382,11 @@ class FmhaFwdApiTrait:
return f"a.hdim_v % {vec} == 0"
else:
assert False
+ elif self.pipeline_tag == "qr_hpad":
+ if self.dvpad == "t":
+ return "a.hdim_v % 8 == 0"
+ else:
+ assert False
elif self.pipeline_tag in ["qr", "qs", "qr_async_trload", "qr_async_trload_v3"]:
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
if self.dvpad == "t":
@@ -634,6 +660,7 @@ class FmhaFwdKernel:
F_pipeline: FmhaFwdPipeline
_KERNEL_HEADER: ClassVar[str] = FMHA_FWD_KERNEL_HEADER
+ _KERNEL_HEADER_QR_HPAD: ClassVar[str] = FMHA_FWD_KERNEL_HEADER_QR_HPAD
_KERNEL_BODY_TEMPLATE: ClassVar[str] = FMHA_FWD_KERNEL_BODY_TEMPLATE
@classmethod
@@ -643,6 +670,12 @@ class FmhaFwdKernel:
else:
return "ck_tile::FmhaFwdKernel"
+ @classmethod
+ def _get_kernel_header(cls, pipeline_tag):
+ if pipeline_tag == "qr_hpad":
+ return cls._KERNEL_HEADER_QR_HPAD
+ return cls._KERNEL_HEADER
+
@classmethod
def _get_cpp_kargs_creator_func_name(cls, pipeline_tag):
if pipeline_tag == "qr_async_trload_v3":
@@ -651,7 +684,9 @@ class FmhaFwdKernel:
return "fmha_fwd_create_kargs_and_grids"
def render(self) -> str:
- return type(self)._KERNEL_HEADER + type(self)._KERNEL_BODY_TEMPLATE.format(
+ return type(self)._get_kernel_header(self.F_pipeline.tag) + type(
+ self
+ )._KERNEL_BODY_TEMPLATE.format(
F_kname=self.name,
F_arch=self.F_arch,
F_hdim=self.F_hdim,
@@ -1144,6 +1179,37 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
def supported_dtypes(cls) -> Tuple[str]:
return cls._DT_FP16_BF16
+ @classmethod
+ def get_rules(cls) -> List[CompatibilityRule]:
+ rules = super().get_rules()
+
+ # For gfx11 fp16/bf16 d128, use dpad=dvpad=t for the 64x32 tile:
+ # the exact-hdim variant (dpad=dvpad=f) is much slower here.
+ def check_d128_tile_pipeline(
+ problem_ctx: ProblemContext, kernel_ctx: KernelContext
+ ) -> bool:
+ if problem_ctx.dtype not in cls._DT_FP16_BF16:
+ return True
+
+ if (problem_ctx.hdim, problem_ctx.hdim_v) != (128, 128):
+ return True
+
+ is_64x32_tile = kernel_ctx.tile.F_bm0 == 64 and kernel_ctx.tile.F_bn0 == 32
+ pads_hdim = (
+ kernel_ctx.pipeline.F_dpad == "t" and kernel_ctx.pipeline.F_dvpad == "t"
+ )
+ exact_hdim = (
+ kernel_ctx.pipeline.F_dpad == "f" and kernel_ctx.pipeline.F_dvpad == "f"
+ )
+
+ if is_64x32_tile:
+ return pads_hdim
+
+ return exact_hdim
+
+ rules.append(check_d128_tile_pipeline)
+ return rules
+
@classmethod
def get_hdim_tile_size_dict(cls, dtype: str) -> Optional[dict]:
if dtype in cls._DT_FP16_BF16:
@@ -1152,7 +1218,8 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
- (128, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
+ (128, 128) : [FmhaFwdTileSize( 64, 32, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, 6, CppConstraint("a.hdim_q != 128 || a.hdim_v != 128")),
+ FmhaFwdTileSize( 64, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint("a.max_seqlen_q < 4096")),
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)],
(192, 128) : [FmhaFwdTileSize( 64, 64, 32, 128, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize(128, 64, 32, 256, 32, 256, 8, 1, 1, 8, 1, 1, 16, 16, 16, 16, 16, 16, 6)]
@@ -1179,7 +1246,9 @@ class KernelComponentFactoryGfx11(CompatibilityRuleFactory):
# Keep only ttff/tttt for gfx11: ffff path is often similar or worse
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
- pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
+ pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
+ if receipt == 1:
+ pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
return pipelines
@@ -1251,7 +1320,9 @@ class KernelComponentFactoryGfx12(CompatibilityRuleFactory):
):
# pipelines.append(FmhaFwdPipeline("qr", "row", "f", "f", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "f", "f", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
- pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
+ pipelines.append(FmhaFwdPipeline("qr_hpad", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
+ if receipt == 1:
+ pipelines.append(FmhaFwdPipeline("qr", "row", "t", "t", "t", "t", logits, bias, lse, dropout, qscale, mask, skip, "f", sink)) # fmt: skip
elif dtype in cls._DT_FP8_FP8BF16 or dtype in cls._DT_FP8FP32:
# no need lse/dropout kernels
for logits, qscale, mask, bias in itertools.product(
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp
index 659bdd995b..a1a98867c6 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp
@@ -13,6 +13,7 @@ enum class BlockFmhaPipelineEnum
QSKSVS,
QRKSVS_ASYNC_TRLOAD,
QRKSVS_ASYNC_TRLOAD_V3,
+ QRKSVS_HPAD,
};
template
@@ -40,4 +41,10 @@ struct BlockFmhaPipelineEnumToStr
static constexpr const char* name = "qr_async_trload";
};
+template <>
+struct BlockFmhaPipelineEnumToStr
+{
+ static constexpr const char* name = "qr_hpad";
+};
+
} // namespace ck_tile
diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
index b207c62181..48c79177d4 100644
--- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
+++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp
@@ -14,7 +14,9 @@
namespace ck_tile {
// This pipeline is qkv all located in LDS
-template
+template
struct BlockFmhaPipelineQRKSVS
{
using Problem = remove_cvref_t;
@@ -54,17 +56,18 @@ struct BlockFmhaPipelineQRKSVS
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
- static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
- static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
- static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
- static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
- static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
- static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
- static constexpr auto BiasEnum = Problem::BiasEnum;
- static constexpr bool kStoreLSE = Problem::kStoreLSE;
- static constexpr bool kHasDropout = Problem::kHasDropout;
- static constexpr auto QScaleEnum = Problem::QScaleEnum;
- static constexpr bool kHasSink = Problem::kHasSink;
+ static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
+ static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
+ static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
+ static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
+ static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
+ static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
+ static constexpr auto BiasEnum = Problem::BiasEnum;
+ static constexpr bool kStoreLSE = Problem::kStoreLSE;
+ static constexpr bool kHasDropout = Problem::kHasDropout;
+ static constexpr auto QScaleEnum = Problem::QScaleEnum;
+ static constexpr bool kHasSink = Problem::kHasSink;
+ static constexpr bool kPaddedVecLoadStore = PaddedVecLoadStore_;
static constexpr ck_tile::index_t kQKScaleGranularity = Problem::kQKScaleGranularity;
static constexpr ck_tile::index_t kVScaleGranularity = Problem::kVScaleGranularity;
@@ -80,23 +83,29 @@ struct BlockFmhaPipelineQRKSVS
(kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
!kHasLogitsSoftCap)) ||
(!CK_TILE_FMHA_FWD_FAST_EXP2 && !kHasLogitsSoftCap));
+ static_assert(!kPaddedVecLoadStore || (kPadHeadDimQ && kPadHeadDimV),
+ "padded vector load/store fast path only applies to padded head-dim kernels");
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
- static constexpr index_t kAlignmentQ = kPadHeadDimQ ? numeric_traits::PackedSize
- : Policy::template GetAlignmentQ();
- static constexpr index_t kAlignmentK = kPadHeadDimQ ? numeric_traits::PackedSize
- : Policy::template GetAlignmentK();
+ static constexpr index_t kAlignmentQ = (kPadHeadDimQ && !kPaddedVecLoadStore)
+ ? numeric_traits::PackedSize
+ : Policy::template GetAlignmentQ();
+ static constexpr index_t kAlignmentK = (kPadHeadDimQ && !kPaddedVecLoadStore)
+ ? numeric_traits::PackedSize
+ : Policy::template GetAlignmentK();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v)
- return kPadHeadDimV ? 1 : Policy::template GetAlignmentV();
+ return (kPadHeadDimV && !kPaddedVecLoadStore)
+ ? 1
+ : Policy::template GetAlignmentV();
else
return kPadSeqLenK ? numeric_traits::PackedSize
: Policy::template GetAlignmentV();
}();
static constexpr index_t kAlignmentO =
- kPadHeadDimV ? 1 : Policy::template GetAlignmentO();
+ (kPadHeadDimV && !kPaddedVecLoadStore) ? 1 : Policy::template GetAlignmentO();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias();
static constexpr index_t kAlignmentRandVal =
@@ -548,8 +557,25 @@ struct BlockFmhaPipelineQRKSVS
});
}
- const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile
- { // tail
+ auto v_prefetch = decltype(load_tile(v_dram_window)){};
+ enum class VPrefetchPoint
+ {
+ BeforeGemm0Tail,
+ AfterGemm0Tail,
+ AfterSoftmax
+ };
+
+#if defined(__gfx11__) || defined(__gfx12__)
+ constexpr auto kVPrefetch =
+ kPadHeadDimV ? VPrefetchPoint::AfterSoftmax : VPrefetchPoint::AfterGemm0Tail;
+#else
+ constexpr auto kVPrefetch = VPrefetchPoint::BeforeGemm0Tail;
+#endif
+ if constexpr(kVPrefetch == VPrefetchPoint::BeforeGemm0Tail)
+ {
+ load_tile(v_prefetch, v_dram_window); // prefetch load v tile
+ }
+ { // tail
block_sync_lds();
run_gemm_0(number{});
block_sync_lds();
@@ -562,6 +588,10 @@ struct BlockFmhaPipelineQRKSVS
run_gemm_0(number{});
}
+ if constexpr(kVPrefetch == VPrefetchPoint::AfterGemm0Tail)
+ {
+ load_tile(v_prefetch, v_dram_window);
+ }
// dequant
auto s_acc_element_func_ = [&s_acc_element_func, k_descale]() {
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
@@ -819,6 +849,11 @@ struct BlockFmhaPipelineQRKSVS
randval_ptr, seq_offset, p_compute, randval_dram_window);
}
+ if constexpr(kVPrefetch == VPrefetchPoint::AfterSoftmax)
+ {
+ load_tile(v_prefetch, v_dram_window);
+ }
+
block_sync_lds();
if constexpr(std::is_same_v)
{
@@ -1098,4 +1133,7 @@ struct BlockFmhaPipelineQRKSVS
}
};
+template
+using BlockFmhaPipelineQRKSVSHpad = BlockFmhaPipelineQRKSVS;
+
} // namespace ck_tile
From 40290297cdc6532aa2da9f91cb8ace60f00c5c77 Mon Sep 17 00:00:00 2001
From: Vidyasagar Ananthan
Date: Thu, 9 Apr 2026 10:38:33 -0700
Subject: [PATCH 04/34] [CK] [CK_Tile] Add GroupConv to Kernel Dispatcher
(#5168)
## Motivation
This PR adds CK Tile grouped convolution (forward, backward-data,
backward-weight) support to the kernel dispatcher, unifying it with the
existing dispatcher GEMM infrastructure in both architecture and
usability. The dispatcher provides a unified kernel dispatch system with
C++ and Python frontends and, until now, supported only GEMM
operations. This PR enables framework integrators to use the same
declarative kernel workflow for convolutions as they do for GEMM:
declare kernels, JIT-build a registry, select kernels within the
registry at runtime, and dispatch to the GPU. Future PRs will add
runtime kernel selection heuristics for autotuning kernel parameters
based on the (problem, hardware arch) pair.
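
The declare → build → select → dispatch workflow can be sketched as follows. This is a self-contained mock whose class and method names loosely mirror the `registry.build()`/`select()` and `dispatcher.run()` calls mentioned in this PR; it illustrates the pattern only and is not the actual CK Tile dispatcher API.

```python
# Illustrative mock of the dispatcher workflow; not the real CK Tile API.
from dataclasses import dataclass


@dataclass(frozen=True)
class ConvProblem:
    n: int   # batch
    c: int   # input channels
    k: int   # output channels
    hw: int  # spatial size


class Registry:
    def __init__(self):
        self._kernels = []  # (predicate, name) pairs

    def declare(self, name, predicate):
        self._kernels.append((predicate, name))

    def build(self):
        # The real registry JIT-compiles each declared kernel here
        # (optionally in parallel); the mock just freezes the list.
        self._kernels = tuple(self._kernels)

    def select(self, problem):
        # Heuristic selection: first kernel whose predicate accepts it.
        for predicate, name in self._kernels:
            if predicate(problem):
                return name
        return None


class Dispatcher:
    def __init__(self, registry):
        self._registry = registry

    def run(self, problem):
        kernel = self._registry.select(problem)
        if kernel is None:
            raise LookupError("no kernel supports this problem")
        return kernel  # the real dispatcher launches the GPU kernel here


registry = Registry()
registry.declare("conv_fwd_large_tile", lambda p: p.hw >= 56)
registry.declare("conv_fwd_small_tile", lambda p: True)
registry.build()

dispatcher = Dispatcher(registry)
print(dispatcher.run(ConvProblem(n=32, c=64, k=64, hw=56)))  # conv_fwd_large_tile
print(dispatcher.run(ConvProblem(n=32, c=64, k=64, hw=7)))   # conv_fwd_small_tile
```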
## Technical Details
Grouped convolution support has been added to the CK Tile Dispatcher: a
generated `generated_conv_backend.hpp` enables `dispatcher.run(in, wei,
out, problem)` for all 6 conv variants (fwd/bwd-data/bwd-weight x
2D/3D), with runtime heuristic kernel selection and a
`GroupedConvKernelKey` carrying the full `ConvConfigBase` fields. On the
Python side, this adds parallel JIT compilation via
`registry.build(max_workers)` and heuristic `registry.select()`. The PR
includes 7 C++ and 6 Python examples covering all directions with CPU
reference validation, plus shared infrastructure improvements
(BaseRegistry CRTP, structured exceptions). As a sanity check, the JIT
compile time for a single kernel remains the same, while building
multiple kernels benefits from parallelism:

| Kernels | 1 worker | 8 workers |
|---------|----------|-----------|
| 1       | 7.7 s    | 7.7 s     |
| 2       | 15.9 s   | 8.2 s     |
| 4       | 33.4 s   | 9.7 s     |
| 6       | 52.3 s   | 10.2 s    |
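
The scaling in the table comes from farming independent kernel compilations out to a worker pool. The sketch below shows the idea; `compile_kernel` is a stand-in stub for a real JIT/compiler invocation, and the function and parameter names here are illustrative rather than the actual `registry.build` implementation.

```python
# Sketch of parallel JIT builds: independent compiles run on a worker pool.
import time
from concurrent.futures import ThreadPoolExecutor


def compile_kernel(name: str) -> str:
    time.sleep(0.01)  # stand-in for a multi-second JIT compile
    return f"{name}.so"


def build(kernel_names, max_workers=1):
    # With max_workers=1 this degenerates to a serial build; with more
    # workers, wall-clock time approaches the longest single compile
    # plus overhead, as in the timing table above.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(compile_kernel, kernel_names))


names = [f"conv_fwd_k{i}" for i in range(6)]
artifacts = build(names, max_workers=8)
print(artifacts[0])  # conv_fwd_k0.so
```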
## Test Plan
145 ephemeral unit tests have been added to test basic functionality.
All 30 examples/integration tests run end-to-end on gfx950 (MI350): 7
C++ conv, 7 C++ GEMM, 6 Python conv, 10 Python GEMM. CPU reference
validation for forward, backward-data, and backward-weight (2D) passes
in both the C++ and Python examples.
## Test Result
30 examples pass. Peak performance: 132 TFLOPS (Batch-32 forward 56x56),
53 TFLOPS (pointwise 1x1). CPU reference accuracy: max_abs_diff < 0.002
for all directions (fp16 vs fp32 reference).
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
---------
Co-authored-by: Yaswanth Raparti <113389104+yraparti@users.noreply.github.com>
---
dispatcher/README.md | 141 +-
dispatcher/bindings/README.md | 24 +-
dispatcher/bindings/ctypes/CMakeLists.txt | 4 +-
.../bindings/ctypes/conv_bwdw_ctypes_lib.cpp | 4 +-
.../bindings/ctypes/conv_ctypes_lib.cpp | 643 +++---
dispatcher/codegen/ADDING_NEW_GPU.md | 16 +-
dispatcher/codegen/README.md | 47 +-
dispatcher/codegen/codegen_common.py | 350 ++++
.../generate_dispatcher_registration.py | 8 +-
.../codegen/generate_kernel_wrappers.py | 8 +-
dispatcher/codegen/kernel_config_loader.py | 32 +-
dispatcher/codegen/unified_gemm_codegen.py | 236 +--
.../codegen/unified_grouped_conv_codegen.py | 1757 ++++++++++++++++
dispatcher/examples/CMakeLists.txt | 70 +-
dispatcher/examples/README.md | 45 +-
.../examples/gemm/cpp/02_multi_size.cpp | 20 +-
.../examples/gemm/cpp/07_gfx950_minimal.cpp | 191 ++
dispatcher/examples/gemm/cpp/README.md | 18 +-
.../examples/gemm/python/01_basic_gemm.py | 291 +--
.../examples/gemm/python/02_batch_gemm.py | 35 +-
.../examples/gemm/python/03_benchmark.py | 37 +-
.../examples/gemm/python/04_validation.py | 34 +-
.../gemm/python/05_numpy_integration.py | 6 +-
.../examples/gemm/python/06_json_export.py | 6 +-
.../examples/gemm/python/07_stress_test.py | 6 +-
.../examples/gemm/python/08_heuristics.py | 6 +-
.../examples/gemm/python/09_multi_registry.py | 6 +-
.../gemm/python/10_advanced_benchmark.py | 7 +-
.../examples/gemm/python/11_json_import.py | 16 +-
dispatcher/examples/gemm/python/README.md | 2 +-
.../cpp/01_basic_grouped_conv.cpp | 203 ++
.../grouped_conv/cpp/02_all_directions.cpp | 216 ++
.../cpp/03_benchmark_validation.cpp | 263 +++
.../grouped_conv/cpp/04_registry_json.cpp | 154 ++
.../examples/grouped_conv/cpp/05_bwd_data.cpp | 183 ++
.../grouped_conv/cpp/06_bwd_weight.cpp | 188 ++
.../cpp/07_multi_tile_benchmark.cpp | 226 +++
.../python/01_basic_grouped_conv.py | 271 +++
.../grouped_conv/python/02_forward.py | 222 ++
.../grouped_conv/python/03_bwd_data.py | 214 ++
.../grouped_conv/python/04_bwd_weight.py | 224 ++
.../grouped_conv/python/05_benchmark.py | 318 +++
.../grouped_conv/python/06_registry_json.py | 274 +++
dispatcher/include/ck_tile/dispatcher.hpp | 20 +-
.../include/ck_tile/dispatcher/README.md | 96 +-
.../backends/generated_conv_backend.hpp | 152 ++
.../ck_tile/dispatcher/base_registry.hpp | 199 ++
.../include/ck_tile/dispatcher/dispatcher.hpp | 22 +-
.../ck_tile/dispatcher/dispatcher_error.hpp | 28 +
.../ck_tile/dispatcher/dispatcher_log.hpp | 55 +
.../dispatcher/grouped_conv_config.hpp | 588 ++++++
.../dispatcher/grouped_conv_kernel_decl.hpp | 537 +++++
.../dispatcher/grouped_conv_problem.hpp | 255 +++
.../dispatcher/grouped_conv_registry.hpp | 614 ++++++
.../ck_tile/dispatcher/grouped_conv_utils.hpp | 324 +++
.../include/ck_tile/dispatcher/problem.hpp | 8 +-
.../include/ck_tile/dispatcher/registry.hpp | 105 +-
.../include/ck_tile/dispatcher_conv.hpp | 18 +
.../include/ck_tile/dispatcher_gemm.hpp | 22 +
dispatcher/python/CMakeLists.txt | 2 +-
dispatcher/python/README.md | 48 +-
dispatcher/python/ctypes_utils.py | 715 ++++++-
dispatcher/python/dispatcher_common.py | 372 ++++
dispatcher/python/grouped_conv_utils.py | 1806 +++++++++++++++++
dispatcher/scripts/compile_gemm_examples.py | 87 +-
.../scripts/compile_grouped_conv_examples.py | 882 ++++++++
dispatcher/scripts/example_kernel_builder.py | 396 ++--
.../scripts/generate_conv_dispatch_header.py | 107 +
dispatcher/scripts/parallel_kernel_builder.py | 2 +-
dispatcher/scripts/stress_test_autocorrect.py | 10 +-
dispatcher/src/dispatcher.cpp | 13 +-
dispatcher/src/registry.cpp | 181 +-
dispatcher/tests/CMakeLists.txt | 4 +
dispatcher/tests/test_autocorrect.py | 8 +-
dispatcher/tests/test_codegen_common.py | 244 +++
dispatcher/tests/test_dispatcher_common.py | 243 +++
dispatcher/tests/test_examples_integration.py | 175 +-
dispatcher/tests/test_grouped_conv_codegen.py | 589 ++++++
dispatcher/tests/test_grouped_conv_config.cpp | 112 +
.../tests/test_grouped_conv_kernel_decl.cpp | 141 ++
.../tests/test_grouped_conv_problem.cpp | 245 +++
.../tests/test_grouped_conv_registry.cpp | 230 +++
dispatcher/tests/test_grouped_conv_utils.py | 349 ++++
dispatcher/tests/test_problem_extended.cpp | 8 +-
.../tests/test_real_kernel_multi_size.cpp | 2 +-
.../tests/test_real_kernel_performance.cpp | 2 +-
86 files changed, 15538 insertions(+), 1500 deletions(-)
create mode 100644 dispatcher/codegen/codegen_common.py
create mode 100644 dispatcher/codegen/unified_grouped_conv_codegen.py
create mode 100644 dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp
create mode 100644 dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp
create mode 100644 dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp
create mode 100644 dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp
create mode 100644 dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp
create mode 100644 dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp
create mode 100644 dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp
create mode 100644 dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp
create mode 100644 dispatcher/examples/grouped_conv/python/01_basic_grouped_conv.py
create mode 100644 dispatcher/examples/grouped_conv/python/02_forward.py
create mode 100644 dispatcher/examples/grouped_conv/python/03_bwd_data.py
create mode 100644 dispatcher/examples/grouped_conv/python/04_bwd_weight.py
create mode 100644 dispatcher/examples/grouped_conv/python/05_benchmark.py
create mode 100644 dispatcher/examples/grouped_conv/python/06_registry_json.py
create mode 100644 dispatcher/include/ck_tile/dispatcher/backends/generated_conv_backend.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher/base_registry.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher/dispatcher_error.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher/dispatcher_log.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_config.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_kernel_decl.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_problem.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_registry.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher/grouped_conv_utils.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher_conv.hpp
create mode 100644 dispatcher/include/ck_tile/dispatcher_gemm.hpp
create mode 100644 dispatcher/python/dispatcher_common.py
create mode 100644 dispatcher/python/grouped_conv_utils.py
create mode 100644 dispatcher/scripts/compile_grouped_conv_examples.py
create mode 100644 dispatcher/scripts/generate_conv_dispatch_header.py
create mode 100644 dispatcher/tests/test_codegen_common.py
create mode 100644 dispatcher/tests/test_dispatcher_common.py
create mode 100644 dispatcher/tests/test_grouped_conv_codegen.py
create mode 100644 dispatcher/tests/test_grouped_conv_config.cpp
create mode 100644 dispatcher/tests/test_grouped_conv_kernel_decl.cpp
create mode 100644 dispatcher/tests/test_grouped_conv_problem.cpp
create mode 100644 dispatcher/tests/test_grouped_conv_registry.cpp
create mode 100644 dispatcher/tests/test_grouped_conv_utils.py
diff --git a/dispatcher/README.md b/dispatcher/README.md
index 1395285d60..dc864f7c62 100644
--- a/dispatcher/README.md
+++ b/dispatcher/README.md
@@ -1,6 +1,6 @@
# CK Tile Dispatcher
-A unified kernel dispatch system for AMD GPUs with C++ and Python frontends.
+A unified kernel dispatch system for AMD GPUs with C++ and Python frontends, supporting GEMM and Grouped Convolution operations.
**Validated Platform:** AMD Instinct MI300 series (gfx942)
@@ -342,8 +342,8 @@ ls examples/libdispatcher_gemm_lib.so
| `CMAKE_PREFIX_PATH` | - | ROCm installation path |
| `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler |
-⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
-⚠️ **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
+WARNING: **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
+WARNING: **Important:** Note that the current system provides single GPU target support for architecture-based kernel filtering, please do not use multiple GPU targets at a time (if necessary, please compile into different build directories).
---
@@ -363,6 +363,15 @@ cd build/examples
./gemm_04_heuristics # Heuristic kernel selection
./gemm_05_json_export # Registry JSON export
./gemm_06_multi_registry # Multiple registries
+
+# Grouped Convolution Examples
+./grouped_conv_01_basic # Declaration patterns + GPU execution
+./grouped_conv_02_all_dirs # Forward/BwdData/BwdWeight with GPU
+./grouped_conv_03_bench_val # Benchmark + CPU reference validation
+./grouped_conv_04_registry_json # Heuristic selection + JSON export
+./grouped_conv_05_bwd_data # Backward data + CPU validation
+./grouped_conv_06_bwd_weight # Backward weight + CPU validation
+./grouped_conv_07_benchmark # Multi-tile ResNet benchmark
```
### Python Examples
@@ -375,8 +384,16 @@ cd /path/to/composable_kernel/dispatcher
# GEMM Examples
python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM
python3 examples/gemm/python/04_validation.py # CPU reference validation
-python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels)
+python3 examples/gemm/python/07_stress_test.py # Stress test
python3 examples/gemm/python/08_heuristics.py # Heuristic selection
+
+# Grouped Convolution Examples
+python3 examples/grouped_conv/python/01_basic_grouped_conv.py # Config patterns + registry + GPU
+python3 examples/grouped_conv/python/02_forward.py # Forward 2D/3D + CPU ref
+python3 examples/grouped_conv/python/03_bwd_data.py # Backward data + CPU ref
+python3 examples/grouped_conv/python/04_bwd_weight.py # Backward weight + CPU ref
+python3 examples/grouped_conv/python/05_benchmark.py # Multi-problem benchmark
+python3 examples/grouped_conv/python/06_registry_json.py # Heuristic selection + JSON
```
### Example Output
@@ -647,7 +664,7 @@ lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so")
### Data Flow
```
-KernelConfig → Registry → Dispatcher → GPU Execution
+KernelConfig -> Registry -> Dispatcher -> GPU Execution
```
1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts)
@@ -843,31 +860,49 @@ make -j$(nproc)
```
dispatcher/
-├── README.md # This file
-├── CMakeLists.txt # Build configuration
-│
-├── include/ck_tile/dispatcher/ # C++ headers
-│ ├── dispatcher.hpp # GEMM dispatcher
-│ ├── registry.hpp # Kernel registry
-│ └── kernel_key.hpp # Kernel configuration
-│
-├── src/ # C++ implementation
-│
-├── codegen/ # Kernel generation
-│ ├── unified_gemm_codegen.py # GEMM kernel generator
-│ └── arch_specs.json # GPU specifications
-│
-├── bindings/ctypes/ # Python ctypes interface
-│ └── gemm_ctypes_lib.cpp # GEMM Python library
-│
-├── examples/ # Examples
-│ └── gemm/
-│ ├── cpp/ # C++ GEMM examples (01-06)
-│ └── python/ # Python GEMM examples (01-11)
-│
-├── scripts/ # Build scripts
-│
-└── tests/ # Unit tests
+|---- README.md # This file
+|---- CMakeLists.txt # Build configuration
+|
+|---- include/ck_tile/dispatcher/ # C++ headers
+| |---- dispatcher.hpp # Main dispatcher include
+| |---- registry.hpp # GEMM kernel registry
+| |---- kernel_key.hpp # Kernel configuration
+| |---- grouped_conv_config.hpp # Grouped conv configuration
+| |---- grouped_conv_problem.hpp # Grouped conv problem (with builder)
+| |---- grouped_conv_kernel_decl.hpp # Grouped conv kernel declarations
+| |---- grouped_conv_registry.hpp # Grouped conv registry (thread-safe)
+| +---- grouped_conv_utils.hpp # Grouped conv utilities
+|
+|---- src/ # C++ implementation
+|
+|---- codegen/ # Kernel generation
+| |---- codegen_common.py # Shared: TileConfig, TraitConfigBase, type mappings
+| |---- unified_gemm_codegen.py # GEMM kernel generator
+| |---- unified_grouped_conv_codegen.py # Grouped conv kernel generator
+| +---- arch_specs.json # GPU specifications
+|
+|---- python/ # Python utilities
+| |---- dispatcher_common.py # Shared: paths, validation, Colors, phased output
+| |---- ctypes_utils.py # GEMM ctypes utilities
+| +---- grouped_conv_utils.py # Grouped conv utilities
+|
+|---- scripts/ # Build scripts
+| |---- compile_gemm_examples.py # GEMM build script
+| +---- compile_grouped_conv_examples.py # Grouped conv build script
+|
+|---- bindings/ctypes/ # Python ctypes interface
+| |---- gemm_ctypes_lib.cpp # GEMM Python library
+| +---- conv_ctypes_lib.cpp # Grouped conv Python library
+|
+|---- examples/ # Examples
+| |---- gemm/
+| | |---- cpp/ # C++ GEMM examples (01-07)
+| | +---- python/ # Python GEMM examples (01-11)
+| +---- grouped_conv/
+| |---- cpp/ # C++ Grouped Conv examples (01-07)
+| +---- python/ # Python Grouped Conv examples (01-06)
+|
++---- tests/ # Unit tests (C++ and Python)
```
---
@@ -879,17 +914,49 @@ dispatcher/
| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) |
| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) |
| Codegen | [codegen/README.md](codegen/README.md) |
+| Python Utils | [python/README.md](python/README.md) |
+| C++ Headers | [include/ck_tile/dispatcher/README.md](include/ck_tile/dispatcher/README.md) |
---
-## Archived Content
+## Grouped Convolution Support
-Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`:
-- `examples/conv/cpp/` - 11 C++ convolution examples
-- `examples/conv/python/` - 14 Python convolution examples
-- `codegen/unified_conv_codegen.py` - Conv kernel generator
-- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers
-- `python/conv_utils.py` - Conv Python utilities
+Grouped convolution is fully supported alongside GEMM, with shared infrastructure to eliminate duplication.
+
+### Python
+
+```bash
+# Generate grouped conv kernels
+python3 codegen/unified_grouped_conv_codegen.py \
+ --output-dir build/generated_kernels \
+ --datatype fp16 --variant forward --ndim-spatial 2
+
+# Build grouped conv examples
+python3 scripts/compile_grouped_conv_examples.py examples/grouped_conv/cpp/01_basic_grouped_conv.cpp
+```
+
+### Key Files
+
+| Component | File |
+|-----------|------|
+| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` |
+| Python Codegen | `codegen/unified_grouped_conv_codegen.py` |
+| Python Utils | `python/grouped_conv_utils.py` |
+| Build Script | `scripts/compile_grouped_conv_examples.py` |
+| Shared Codegen | `codegen/codegen_common.py` |
+| Shared Utils | `python/dispatcher_common.py` |
+
+### Variants
+
+- **Forward** (`grouped_conv_fwd`) - Standard grouped convolution
+- **Backward Data** (`grouped_conv_bwd_data`) - Gradient w.r.t. input
+- **Backward Weight** (`grouped_conv_bwd_weight`) - Gradient w.r.t. weights
+
+### Shared Infrastructure
+
+GEMM and grouped convolution share common code to avoid duplication:
+- `codegen/codegen_common.py` - TileConfig, TraitConfigBase, type mappings, parallel generation, arch-aware expansion
+- `python/dispatcher_common.py` - Path helpers, validation, auto-correction, Colors, phased output
---
diff --git a/dispatcher/bindings/README.md b/dispatcher/bindings/README.md
index 7cda21f6ec..04029d32a9 100644
--- a/dispatcher/bindings/README.md
+++ b/dispatcher/bindings/README.md
@@ -6,13 +6,13 @@ This directory contains language bindings for the CK Tile Dispatcher.
```
bindings/
-├── ctypes/ # Python ctypes bindings (C API)
-│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API
-│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data)
-│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API
-│ ├── gpu_helper.cpp # CLI helper for Python
-│ └── CMakeLists.txt
-└── README.md
+|---- ctypes/ # Python ctypes bindings (C API)
+| |---- gemm_ctypes_lib.cpp # GEMM dispatcher C API
+| |---- conv_ctypes_lib.cpp # Grouped conv dispatcher C API (fwd + bwd_data)
+| |---- conv_bwdw_ctypes_lib.cpp # Grouped conv backward weight C API (separate library)
+| |---- gpu_helper.cpp # CLI helper for Python
+| +---- CMakeLists.txt
++---- README.md
```
## ctypes Bindings
@@ -65,7 +65,7 @@ lib.dispatcher_cleanup()
| `dispatcher_export_registry_json()` | Export registry as JSON |
| `dispatcher_cleanup()` | Release resources |
-### Convolution API
+### Grouped Convolution API
| Function | Description |
|----------|-------------|
@@ -105,5 +105,11 @@ Output is JSON for easy parsing:
See the examples that use these bindings:
- **GEMM**: `dispatcher/examples/gemm/python/`
-- **Conv**: `dispatcher/examples/conv/python/`
+
+### Grouped Convolution
+
+Grouped convolution C++ headers and Python utilities are in:
+- **C++ Headers**: `dispatcher/include/ck_tile/dispatcher/grouped_conv_*.hpp`
+- **Python Utils**: `dispatcher/python/grouped_conv_utils.py`
+- **Build Script**: `dispatcher/scripts/compile_grouped_conv_examples.py`
diff --git a/dispatcher/bindings/ctypes/CMakeLists.txt b/dispatcher/bindings/ctypes/CMakeLists.txt
index 804e5e9bd7..18314017f2 100644
--- a/dispatcher/bindings/ctypes/CMakeLists.txt
+++ b/dispatcher/bindings/ctypes/CMakeLists.txt
@@ -78,7 +78,7 @@ endif()
# Look for forward kernels
file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp")
# Look for backward data kernels
-file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp")
+file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwd_data_*.hpp")
# Fallback: any conv kernel (for backwards compatibility)
file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp")
@@ -112,7 +112,7 @@ endif()
# Add backward data kernel if available
if(CONV_BWDD_KERNEL_HEADERS)
list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER)
- message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}")
+ message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWD_DATA_KERNEL_HEADER}")
target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER})
target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE)
endif()
diff --git a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp
index 09e058f80f..96b4aa3462 100644
--- a/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp
+++ b/dispatcher/bindings/ctypes/conv_bwdw_ctypes_lib.cpp
@@ -53,6 +53,7 @@ struct ConvBwdwProblemC
int stride_d, stride_h, stride_w;
int pad_d, pad_h, pad_w;
int dilation_d, dilation_h, dilation_w;
+ int split_k;
};
// =============================================================================
@@ -108,8 +109,7 @@ static float run_bwd_weight_impl(const void* input_ptr,
grad_weight_ptr, // wei_ptr = grad_weight (output)
{}, // ds_ptr
grad_output_ptr, // out_ptr = grad_output
- 1 // k_batch
- );
+ (prob->split_k > 1) ? prob->split_k : 1);
ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10};
diff --git a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp
index d3c64621a7..002219c82e 100644
--- a/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp
+++ b/dispatcher/bindings/ctypes/conv_ctypes_lib.cpp
@@ -1,128 +1,46 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
-
-/**
- * Convolution Dispatcher ctypes Library
- *
- * Provides C API for Python ctypes integration.
- * Supports forward convolution. Backward operations require additional headers.
- *
- * REQUIRED: Forward kernel header must be force-included via -include flag.
- * OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE
- *
- * Usage from Python:
- * lib = ctypes.CDLL("libdispatcher_conv.so")
- * lib.conv_dispatcher_init()
- * lib.conv_dispatcher_run(...)
- */
+//
+// Multi-kernel grouped convolution dispatcher for Python ctypes.
+//
+// Supports: forward / backward-data / backward-weight x 2D / 3D
+//
+// The dispatch header (conv_python_dispatch.hpp) is force-included via
+// -include and brings in ALL compiled kernels with these aliases:
+//
+// 2D launchers (from include_all headers):
+// SelectedConvKernelLauncher (forward 2D)
+// SelectedConvBwdDataLauncher (backward-data 2D)
+// SelectedConvBwdWeightLauncher (backward-weight 2D)
+//
+// 3D launchers (from dispatch header):
+// ConvFwd3dLauncher (forward 3D)
+// ConvBwdData3dLauncher (backward-data 3D)
+// ConvBwdWeight3dLauncher (backward-weight 3D)
+//
+// Usage from Python:
+// lib = ctypes.CDLL("libdispatcher_conv_lib.so")
+// lib.conv_dispatcher_init()
+// lib.conv_dispatcher_run(input, weight, output, &problem, stream)
#include
-#include
-#include
+#include
#include
-#include "ck_tile/dispatcher/conv_utils.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
-using namespace ck_tile::dispatcher;
-
-// Global state (using shared_ptr for safe memory management)
-static std::shared_ptr g_registry = nullptr;
-static std::shared_ptr g_dispatcher = nullptr;
-static std::vector g_kernels;
-
extern "C" {
-// =============================================================================
-// Initialization
-// =============================================================================
-
-int conv_dispatcher_init()
+// =========================================================================
+// Problem definition (matches Python ctypes struct exactly)
+// =========================================================================
+enum ConvDirection
{
- if(g_registry)
- return 0; // Already initialized
-
- g_registry = std::make_shared();
- g_dispatcher = std::make_shared(g_registry.get());
-
- // Register kernel configurations using simple ConvKernelSet
- // (actual kernel launch uses the force-included SelectedConvKernelLauncher)
- using namespace ck_tile::dispatcher::conv_decl;
-
- // Forward kernels (required - must be force-included)
- // Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb
- ConvKernelSet fwd_set;
- fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
- ConvAlgorithm()
- .tile(128, 128, 64) // tile_m x tile_n x tile_k
- .wave(2, 2, 1)
- .warp(32, 32, 16)
- .pipeline("compv4")
- .scheduler("intrawave"),
- "gfx942");
- g_registry->register_set(fwd_set, ConvRegistry::Priority::High);
-
-#ifdef CONV_BWD_DATA_AVAILABLE
- // Backward data kernels
- // Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16
- ConvKernelSet bwd_data_set;
- bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2),
- ConvAlgorithm()
- .tile(128, 128, 64) // tile_m x tile_n x tile_k
- .wave(2, 2, 1)
- .warp(32, 32, 16)
- .pipeline("compv3")
- .scheduler("intrawave"),
- "gfx942");
- g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High);
-#endif
-
- return 0;
-}
-
-int conv_dispatcher_cleanup()
-{
- // shared_ptr automatically handles cleanup when reset
- g_dispatcher.reset();
- g_registry.reset();
- g_kernels.clear();
- return 0;
-}
-
-// =============================================================================
-// Registry Management
-// =============================================================================
-
-int conv_dispatcher_get_kernel_count()
-{
- if(!g_registry)
- return 0;
- return static_cast(g_registry->size());
-}
-
-int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size)
-{
- if(index < 0 || !buffer || buffer_size <= 0)
- return -1;
-
- if(!g_registry)
- return -1;
-
- // Use registry to get kernel names (they are registered with full names)
- const auto& kernels = g_registry->all_kernels();
- if(static_cast(index) >= kernels.size())
- return -1;
-
- const auto* kernel = kernels[index];
- std::strncpy(buffer, kernel->name().c_str(), buffer_size - 1);
- buffer[buffer_size - 1] = '\0';
- return 0;
-}
-
-// =============================================================================
-// Problem Definition
-// =============================================================================
+ CONV_FORWARD = 0,
+ CONV_BWD_DATA = 1,
+ CONV_BWD_WEIGHT = 2
+};
struct ConvProblemC
{
@@ -132,267 +50,33 @@ struct ConvProblemC
int stride_d, stride_h, stride_w;
int pad_d, pad_h, pad_w;
int dilation_d, dilation_h, dilation_w;
- int direction; // 0=forward, 1=bwd_data, 2=bwd_weight
+ int direction;
+ int split_k;
};
-// =============================================================================
-// Kernel Selection
-// =============================================================================
+// =========================================================================
+// Initialization / lifecycle
+// =========================================================================
+int conv_dispatcher_init() { return 0; }
+int conv_dispatcher_cleanup() { return 0; }
-int conv_dispatcher_is_supported(const ConvProblemC* prob)
-{
- if(!g_registry || !prob)
- return 0;
-
- ConvProblem problem;
- problem.N = prob->N;
- problem.G = prob->G;
- problem.C = prob->C;
- problem.K = prob->K;
- problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
- problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
- problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
- problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
- problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
- problem.op = static_cast(prob->direction);
- problem.compute_output_size();
-
- const auto* kernel = g_dispatcher->select(problem);
- return kernel ? 1 : 0;
-}
-
-int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size)
-{
- if(!g_registry || !prob || !kernel_name || buffer_size <= 0)
- return -1;
-
- ConvProblem problem;
- problem.N = prob->N;
- problem.G = prob->G;
- problem.C = prob->C;
- problem.K = prob->K;
- problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
- problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
- problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
- problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
- problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
- problem.op = static_cast(prob->direction);
- problem.compute_output_size();
-
- const auto* kernel = g_dispatcher->select(problem);
- if(!kernel)
- return -1;
-
- std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1);
- kernel_name[buffer_size - 1] = '\0';
-
- return 0;
-}
-
-// =============================================================================
-// Convolution Execution
-// =============================================================================
-
-// Helper to build ConvParam
-static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob)
-{
- // Determine if this is 2D or 3D convolution
- const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
-
- if(is_3d)
- {
- // 3D convolution: use all spatial dimensions
- return ck_tile::conv::ConvParam{3,
- prob->G,
- prob->N,
- prob->K,
- prob->C,
- {prob->filter_z, prob->filter_y, prob->filter_x},
- {prob->input_d, prob->input_h, prob->input_w},
- {prob->stride_d, prob->stride_h, prob->stride_w},
- {prob->dilation_d, prob->dilation_h, prob->dilation_w},
- {prob->pad_d, prob->pad_h, prob->pad_w},
- {prob->pad_d, prob->pad_h, prob->pad_w}};
- }
- else
- {
- // 2D convolution: only use H, W dimensions
- return ck_tile::conv::ConvParam{2,
- prob->G,
- prob->N,
- prob->K,
- prob->C,
- {prob->filter_y, prob->filter_x},
- {prob->input_h, prob->input_w},
- {prob->stride_h, prob->stride_w},
- {prob->dilation_h, prob->dilation_w},
- {prob->pad_h, prob->pad_w},
- {prob->pad_h, prob->pad_w}};
- }
-}
-
-// Forward convolution (required - kernel header must be force-included)
-static float run_forward(const void* input_ptr,
- const void* weight_ptr,
- void* output_ptr,
- const ConvProblemC* prob,
- void* stream)
-{
- auto conv_param = build_conv_param(prob);
-
- ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1);
-
- ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10};
-
- // SelectedConvKernelLauncher is defined in the force-included forward kernel header
- return SelectedConvKernelLauncher::launch(args, stream_cfg);
-}
-
-#ifdef CONV_BWD_DATA_AVAILABLE
-// Backward data convolution (optional)
-// Computes: grad_input = conv_bwd_data(weight, grad_output)
-//
-// Parameters:
-// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT)
-// weight_ptr: W - frozen weights (const, read-only INPUT)
-// grad_input_ptr: dX - gradient for input (writable, OUTPUT)
-static float run_bwd_data(const void* grad_output_ptr,
- const void* weight_ptr,
- void* grad_input_ptr,
- const ConvProblemC* prob,
- void* stream)
-{
- auto conv_param = build_conv_param(prob);
-
- // CK Tile API uses tensor POSITION names (from forward pass), not data flow:
- // in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data)
- // wei_ptr = weight tensor = weight_ptr (W, const)
- // out_ptr = output tensor position = grad_output_ptr (dY, INPUT to bwd_data)
- ck_tile::GroupedConvBwdDataHostArgs args(
- conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1);
-
- ck_tile::stream_config stream_cfg{static_cast(stream), true, 1, 3, 10};
-
- return SelectedConvBwdDataLauncher::launch(args, stream_cfg);
-}
-#endif
-
-#ifdef CONV_BWD_WEIGHT_AVAILABLE
-// Backward weight convolution (optional)
-// Parameters:
-// input_ptr: original forward input X (const, read-only)
-// grad_output_ptr: gradient from next layer dY (const, read-only)
-// grad_weight_ptr: gradient of weights dW (writable, OUTPUT)
-static float run_bwd_weight(const void* input_ptr,
- const void* grad_output_ptr,
- void* grad_weight_ptr,
- const ConvProblemC* prob,
- void* stream)
-{
- auto conv_param = build_conv_param(prob);
-
- // GroupedConvBwdWeightHostArgs constructor order:
- // (param, in=X, wei=dW (output), ds, out=dY (input), k_batch)
- // Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output)
- ck_tile::GroupedConvBwdWeightHostArgs args(
- conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1);
-
-    ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
-
- return SelectedConvBwdWeightLauncher::launch(args, stream_cfg);
-}
-#endif
-
-/**
- * @brief Execute convolution based on direction specified in prob
- *
- * Parameter mapping varies by direction:
- * Forward (direction=0):
- * input_ptr = X (input tensor)
- * weight_ptr = W (weight tensor)
- * output_ptr = Y (output buffer)
- *
- * Backward Data (direction=1):
- * input_ptr = dY (grad_output - gradient from next layer)
- * weight_ptr = W (weight tensor, frozen)
- * output_ptr = dX (grad_input buffer)
- *
- * Backward Weight (direction=2):
- * input_ptr = X (forward input tensor)
- * weight_ptr = dY (grad_output - gradient from next layer)
- * output_ptr = dW (grad_weight buffer)
- */
-float conv_dispatcher_run(const void* input_ptr,
- const void* weight_ptr,
- void* output_ptr,
- const ConvProblemC* prob,
- void* stream)
-{
- // Validate all required pointers before kernel launch
- if(!g_dispatcher || !prob)
- return -1.0f;
- if(!input_ptr || !weight_ptr || !output_ptr)
- return -1.0f; // Null data pointer would cause kernel crash
-
- // Build problem for kernel selection
- ConvProblem problem;
- problem.N = prob->N;
- problem.G = prob->G;
- problem.C = prob->C;
- problem.K = prob->K;
- problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
- problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
- problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
- problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
- problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
-    problem.op = static_cast<decltype(problem.op)>(prob->direction);
- problem.compute_output_size();
-
- // Select kernel
- const auto* kernel = g_dispatcher->select(problem);
- if(!kernel)
- return -1.0f;
-
- // Dispatch based on direction
- switch(prob->direction)
- {
- case 0: // Forward (always available)
- return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream);
-
-#ifdef CONV_BWD_DATA_AVAILABLE
- case 1: // Backward data
- // Convention: caller passes (grad_output, weight, grad_input_buffer)
- // in the (input_ptr, weight_ptr, output_ptr) slots respectively.
- // run_bwd_data expects: (grad_output, weight, grad_input)
- return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream);
-#endif
-
-#ifdef CONV_BWD_WEIGHT_AVAILABLE
- case 2: // Backward weight
- // Convention: caller passes (input, grad_output, grad_weight_buffer)
- // in the (input_ptr, weight_ptr, output_ptr) slots respectively.
- // run_bwd_weight expects: (input, grad_output, grad_weight)
- return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream);
-#endif
-
- default: return -1.0f;
- }
-}
-
-// =============================================================================
-// Info
-// =============================================================================
-
-const char* conv_dispatcher_version() { return "1.0.0"; }
+// =========================================================================
+// Library info
+// =========================================================================
+const char* conv_dispatcher_version() { return "2.0.0"; }
int conv_dispatcher_has_kernels()
{
- return 1; // Forward kernel is required
+#if defined(CONV_FWD_2D_AVAILABLE) || defined(CONV_FWD_3D_AVAILABLE)
+ return 1;
+#else
+ return 0;
+#endif
}
int conv_dispatcher_has_bwd_data()
{
-#ifdef CONV_BWD_DATA_AVAILABLE
+#if defined(CONV_BWD_DATA_2D_AVAILABLE) || defined(CONV_BWD_DATA_3D_AVAILABLE)
return 1;
#else
return 0;
@@ -401,11 +85,240 @@ int conv_dispatcher_has_bwd_data()
int conv_dispatcher_has_bwd_weight()
{
-#ifdef CONV_BWD_WEIGHT_AVAILABLE
+#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE) || defined(CONV_BWD_WEIGHT_3D_AVAILABLE)
return 1;
#else
return 0;
#endif
}
+int conv_dispatcher_get_kernel_count()
+{
+ return CONV_KERNEL_COUNT; // defined in conv_python_dispatch.hpp
+}
+
+int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size)
+{
+ if(!buffer || buffer_size <= 0 || index < 0 || index >= CONV_KERNEL_COUNT)
+ return -1;
+ std::strncpy(buffer, CONV_KERNEL_NAMES[index], buffer_size - 1);
+ buffer[buffer_size - 1] = '\0';
+ return 0;
+}
+
+// =========================================================================
+// Support query
+// =========================================================================
+bool conv_dispatcher_is_supported(const ConvProblemC* prob)
+{
+ if(!prob)
+ return false;
+ const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
+ switch(prob->direction)
+ {
+ case CONV_FORWARD:
+#if defined(CONV_FWD_3D_AVAILABLE)
+ if(is_3d)
+ return true;
+#endif
+#if defined(CONV_FWD_2D_AVAILABLE)
+ if(!is_3d)
+ return true;
+#endif
+ return false;
+ case CONV_BWD_DATA:
+#if defined(CONV_BWD_DATA_3D_AVAILABLE)
+ if(is_3d)
+ return true;
+#endif
+#if defined(CONV_BWD_DATA_2D_AVAILABLE)
+ if(!is_3d)
+ return true;
+#endif
+ return false;
+ case CONV_BWD_WEIGHT:
+#if defined(CONV_BWD_WEIGHT_3D_AVAILABLE)
+ if(is_3d)
+ return true;
+#endif
+#if defined(CONV_BWD_WEIGHT_2D_AVAILABLE)
+ if(!is_3d)
+ return true;
+#endif
+ return false;
+ default: return false;
+ }
+}
+
+// =========================================================================
+// ConvParam builders
+// =========================================================================
+static ck_tile::conv::ConvParam make_param_2d(const ConvProblemC* p)
+{
+ return ck_tile::conv::ConvParam{2,
+ p->G,
+ p->N,
+ p->K,
+ p->C,
+ {p->filter_y, p->filter_x},
+ {p->input_h, p->input_w},
+ {p->stride_h, p->stride_w},
+ {p->dilation_h, p->dilation_w},
+ {p->pad_h, p->pad_w},
+ {p->pad_h, p->pad_w}};
+}
+
+static ck_tile::conv::ConvParam make_param_3d(const ConvProblemC* p)
+{
+ return ck_tile::conv::ConvParam{3,
+ p->G,
+ p->N,
+ p->K,
+ p->C,
+ {p->filter_z, p->filter_y, p->filter_x},
+ {p->input_d, p->input_h, p->input_w},
+ {p->stride_d, p->stride_h, p->stride_w},
+ {p->dilation_d, p->dilation_h, p->dilation_w},
+ {p->pad_d, p->pad_h, p->pad_w},
+ {p->pad_d, p->pad_h, p->pad_w}};
+}
+
+// =========================================================================
+// Kernel launch helpers
+// =========================================================================
+
+#ifdef CONV_FWD_2D_AVAILABLE
+static float
+launch_fwd_2d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream)
+{
+ auto param = make_param_2d(p);
+ ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1);
+ ck_tile::stream_config sc{stream, true, 1, 3, 10};
+ return SelectedConvKernelLauncher::launch(args, sc);
+}
+#endif
+
+#ifdef CONV_FWD_3D_AVAILABLE
+static float
+launch_fwd_3d(const void* in, const void* wei, void* out, const ConvProblemC* p, hipStream_t stream)
+{
+ auto param = make_param_3d(p);
+ ck_tile::GroupedConvFwdHostArgs<> args(param, in, wei, {}, out, 1);
+ ck_tile::stream_config sc{stream, true, 1, 3, 10};
+ return ConvFwd3dLauncher::launch(args, sc);
+}
+#endif
+
+#ifdef CONV_BWD_DATA_2D_AVAILABLE
+static float launch_bwd_data_2d(
+ const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream)
+{
+ auto param = make_param_2d(p);
+ ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1);
+ ck_tile::stream_config sc{stream, true, 1, 3, 10};
+ return SelectedConvBwdDataLauncher::launch(args, sc);
+}
+#endif
+
+#ifdef CONV_BWD_DATA_3D_AVAILABLE
+static float launch_bwd_data_3d(
+ const void* dy, const void* wei, void* dx, const ConvProblemC* p, hipStream_t stream)
+{
+ auto param = make_param_3d(p);
+ ck_tile::GroupedConvBwdDataHostArgs args(param, dx, wei, {}, dy, 1);
+ ck_tile::stream_config sc{stream, true, 1, 3, 10};
+ return ConvBwdData3dLauncher::launch(args, sc);
+}
+#endif
+
+#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE
+static float launch_bwd_weight_2d(
+ const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream)
+{
+ auto param = make_param_2d(p);
+ const int k_batch = (p->split_k > 1) ? p->split_k : 1;
+ ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch);
+ ck_tile::stream_config sc{stream, true, 1, 3, 10};
+ return SelectedConvBwdWeightLauncher::launch(args, sc);
+}
+#endif
+
+#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE
+static float launch_bwd_weight_3d(
+ const void* x, const void* dy, void* dw, const ConvProblemC* p, hipStream_t stream)
+{
+ auto param = make_param_3d(p);
+ const int k_batch = (p->split_k > 1) ? p->split_k : 1;
+ ck_tile::GroupedConvBwdWeightHostArgs args(param, x, dw, {}, dy, k_batch);
+ ck_tile::stream_config sc{stream, true, 1, 3, 10};
+ return ConvBwdWeight3dLauncher::launch(args, sc);
+}
+#endif
+
+// =========================================================================
+// Main dispatch
+//
+// direction=0 (forward): a=X(input), b=W(weight), c=Y(output)
+// direction=1 (bwd_data): a=dY(grad_out), b=W(weight), c=dX(grad_in)
+// direction=2 (bwd_weight): a=X(input), b=dY(grad_out), c=dW(grad_wei)
+// =========================================================================
+float conv_dispatcher_run(
+ const void* a_ptr, const void* b_ptr, void* c_ptr, const ConvProblemC* prob, void* stream)
+{
+ if(!prob || !a_ptr || !b_ptr || !c_ptr)
+ return -1.0f;
+
+ const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
+    hipStream_t hip_stream = static_cast<hipStream_t>(stream);
+
+ try
+ {
+ switch(prob->direction)
+ {
+ case CONV_FORWARD:
+#ifdef CONV_FWD_3D_AVAILABLE
+ if(is_3d)
+ return launch_fwd_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
+#endif
+#ifdef CONV_FWD_2D_AVAILABLE
+ if(!is_3d)
+ return launch_fwd_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
+#endif
+ return -2.0f;
+
+ case CONV_BWD_DATA:
+#ifdef CONV_BWD_DATA_3D_AVAILABLE
+ if(is_3d)
+ return launch_bwd_data_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
+#endif
+#ifdef CONV_BWD_DATA_2D_AVAILABLE
+ if(!is_3d)
+ return launch_bwd_data_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
+#endif
+ return -2.0f;
+
+ case CONV_BWD_WEIGHT:
+#ifdef CONV_BWD_WEIGHT_3D_AVAILABLE
+ if(is_3d)
+ return launch_bwd_weight_3d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
+#endif
+#ifdef CONV_BWD_WEIGHT_2D_AVAILABLE
+ if(!is_3d)
+ return launch_bwd_weight_2d(a_ptr, b_ptr, c_ptr, prob, hip_stream);
+#endif
+ return -2.0f;
+
+ default: return -1.0f;
+ }
+ }
+ catch(const std::exception&)
+ {
+ return -3.0f; // Kernel rejected args (e.g. unsupported tile/channel combo)
+ }
+ catch(...)
+ {
+ return -3.0f;
+ }
+}
+
} // extern "C"
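The routing and sentinel return codes above (-1.0 for invalid arguments, -2.0 when no kernel was compiled for the direction/dimension combination, -3.0 when the kernel rejects its arguments) can be sketched in Python. Everything here is illustrative — the dict-based `KERNELS` table and `dispatcher_run` name are stand-ins for the macro-guarded C++ launchers, not the real API:

```python
# Sketch of conv_dispatcher_run's direction/dimension routing and
# sentinel error codes. All names are illustrative, not the real API.

CONV_FORWARD, CONV_BWD_DATA, CONV_BWD_WEIGHT = 0, 1, 2

# Pretend-compiled kernel table: (direction, is_3d) -> launcher.
# A real build would populate this from the CONV_*_AVAILABLE macros.
KERNELS = {
    (CONV_FORWARD, False): lambda: 0.12,  # returns elapsed ms
    (CONV_FORWARD, True): lambda: 0.34,
}

def dispatcher_run(a_ptr, b_ptr, c_ptr, prob):
    # -1.0: invalid arguments (null pointers or unknown direction)
    if not (prob and a_ptr and b_ptr and c_ptr):
        return -1.0
    if prob["direction"] not in (CONV_FORWARD, CONV_BWD_DATA, CONV_BWD_WEIGHT):
        return -1.0
    is_3d = prob["input_d"] > 1 or prob["filter_z"] > 1
    launcher = KERNELS.get((prob["direction"], is_3d))
    if launcher is None:
        return -2.0  # direction/dimension combo not compiled in
    try:
        return launcher()
    except Exception:
        return -3.0  # kernel rejected the arguments
```

Note the asymmetry this makes visible: a bwd_data call against a forward-only build is a configuration miss (-2.0), not an argument error (-1.0).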
diff --git a/dispatcher/codegen/ADDING_NEW_GPU.md b/dispatcher/codegen/ADDING_NEW_GPU.md
index 0bd2966a85..664b59b6b1 100644
--- a/dispatcher/codegen/ADDING_NEW_GPU.md
+++ b/dispatcher/codegen/ADDING_NEW_GPU.md
@@ -9,8 +9,8 @@ Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatche
The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications:
```
-arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python)
- → arch_specs_generated.hpp (C++)
+arch_specs.json -> generate_arch_specs.py -> arch_specs_generated.py (Python)
+ -> arch_specs_generated.hpp (C++)
```
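The single-source-of-truth flow above can be sketched as a tiny generator that reads the JSON and emits both a Python module and a C++ header. This is a hedged illustration of the shape of the pipeline only — the field names in `SPECS_JSON` and the emitted symbol names are invented, not what `generate_arch_specs.py` actually produces:

```python
import json

# Illustrative arch_specs.json -> generated-sources flow.
# The real generator emits richer tables; these names are invented.
SPECS_JSON = '{"gfx942": {"wave_size": 64}, "gfx950": {"wave_size": 64}}'

def emit_python(specs):
    # Would be written to arch_specs_generated.py
    return "SUPPORTED_ARCHS = %r\n" % sorted(specs)

def emit_cpp(specs):
    # Would be written to arch_specs_generated.hpp
    archs = ", ".join('"%s"' % a for a in sorted(specs))
    return "inline const char* kSupportedArchs[] = {%s};\n" % archs

specs = json.loads(SPECS_JSON)
py_module = emit_python(specs)
cpp_header = emit_cpp(specs)
```

Because both outputs derive from the same parsed dict, the Python and C++ sides cannot drift apart — that is the point of the single source of truth.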
## Quick Start
@@ -175,14 +175,14 @@ for error in result.errors:
```
codegen/
-├── arch_specs.json # Single source of truth (EDIT THIS)
-├── generate_arch_specs.py # Generator script
-├── arch_specs_generated.py # Generated Python module
-└── ADDING_NEW_GPU.md # This file
+|---- arch_specs.json # Single source of truth (EDIT THIS)
+|---- generate_arch_specs.py # Generator script
+|---- arch_specs_generated.py # Generated Python module
++---- ADDING_NEW_GPU.md # This file
include/ck_tile/dispatcher/
-├── arch_specs_generated.hpp # Generated C++ header
-└── arch_filter.hpp # C++ filter
+|---- arch_specs_generated.hpp # Generated C++ header
++---- arch_filter.hpp # C++ filter
```
## Best Practices
diff --git a/dispatcher/codegen/README.md b/dispatcher/codegen/README.md
index 2d753924f5..40a9b7b8c1 100644
--- a/dispatcher/codegen/README.md
+++ b/dispatcher/codegen/README.md
@@ -1,11 +1,22 @@
-# CK Tile GEMM Unified Code Generator
+# CK Tile Unified Code Generators
-Single source of truth for all GEMM kernel generation.
+Single source of truth for GEMM and Grouped Convolution kernel generation.
> **See also:** [Main Dispatcher README](../README.md) for installation and core concepts.
+## Shared Infrastructure
+
+Both GEMM and Grouped Conv generators share common code via `codegen_common.py`:
+- `TileConfig` - Dataclass for tile dimensions
+- `TraitConfigBase` - Base for kernel trait configurations with arch-aware validation
+- `CommonTypeMappings` - Dtype-to-C++ type mappings
+- `parallel_generate()` - Parallel kernel generation with per-kernel progress logging
+- Arch-aware expansion helpers (`valid_wave_configs`, `valid_warp_configs`, etc.)
+
## Quick Start
+### GEMM
+
```bash
cd dispatcher/codegen
@@ -22,6 +33,25 @@ python3 unified_gemm_codegen.py \
--variants standard preshuffle multi_d
```
+### Grouped Convolution
+
+```bash
+cd dispatcher/codegen
+
+# Generate forward FP16 grouped conv kernels
+python3 unified_grouped_conv_codegen.py \
+ --output-dir ../build/generated_kernels \
+ --datatype fp16 \
+ --variant forward \
+ --ndim-spatial 2
+
+# Generate backward data kernels
+python3 unified_grouped_conv_codegen.py \
+ --output-dir ../build/generated_kernels \
+ --variant backward_data \
+ --ndim-spatial 2
+```
+
## Using from Python
```python
@@ -58,13 +88,13 @@ results = codegen.generate_all()
## Variants
### Standard
-Basic GEMM: `C = A × B`
+Basic GEMM: `C = A x B`
### PreShuffle
Optimized weight access with LDS pre-shuffling. Best for large matrices.
### Multi-D
-Element-wise fusion: `C = op(A × B + D0 + D1 + ...)`
+Element-wise fusion: `C = op(A x B + D0 + D1 + ...)`
Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
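The Multi-D fusion formula can be written as a pure-Python reference on tiny matrices, here with `op = Relu`. This is a semantics sketch only, not how the generated kernels compute:

```python
# Reference semantics for Multi-D fusion: C = op(A x B + D0 + D1 + ...)
def matmul(a, b):
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]

def multi_d_gemm(a, b, ds, op=lambda x: max(x, 0.0)):  # default op: Relu
    c = matmul(a, b)
    for d in ds:  # element-wise add of each D tensor
        c = [[c[i][j] + d[i][j] for j in range(len(c[0]))]
             for i in range(len(c))]
    return [[op(x) for x in row] for row in c]
```

Fusing the adds and the activation into the GEMM epilogue avoids materializing the intermediate `A x B` result in memory.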
@@ -72,10 +102,11 @@ Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
```
generated_kernels/
-├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp
-├── gemm_fp16_rcr_compv4_..._preshuffle.hpp
-├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
-└── ...
+|---- gemm_fp16_rcr_compv4_..._128x128x32_....hpp # GEMM kernels
+|---- gemm_fp16_rcr_compv4_..._preshuffle.hpp
+|---- gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
+|---- grouped_conv_fwd_fp16_nhwgc_..._128x128x32_....hpp # Grouped conv kernels
++---- ...
```
## Configuration Files
diff --git a/dispatcher/codegen/codegen_common.py b/dispatcher/codegen/codegen_common.py
new file mode 100644
index 0000000000..4e9e8de1b3
--- /dev/null
+++ b/dispatcher/codegen/codegen_common.py
@@ -0,0 +1,350 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+
+"""
+Shared codegen infrastructure for GEMM and grouped convolution code generators.
+
+Extracted from unified_gemm_codegen.py, plus arch-aware expansion helpers from the conv generator.
+Both unified_gemm_codegen.py and unified_grouped_conv_codegen.py import from here
+to eliminate duplication.
+"""
+
+import logging
+import concurrent.futures
+from dataclasses import dataclass
+from typing import (
+ Callable,
+ ClassVar,
+ Dict,
+ FrozenSet,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ TypeVar,
+)
+
+log = logging.getLogger(__name__)
+
+T = TypeVar("T")
+R = TypeVar("R")
+
+ANY_INT = -1
+
+
+# ============================================================================
+# Tile and Trait Configuration (shared between GEMM and Conv)
+# ============================================================================
+
+
+@dataclass
+class TileConfig:
+ """Tile configuration parameters shared by GEMM and grouped conv."""
+
+ tile_m: int
+ tile_n: int
+ tile_k: int
+ warp_m: int
+ warp_n: int
+ warp_k: int
+ warp_tile_m: int
+ warp_tile_n: int
+ warp_tile_k: int
+
+ def is_valid(self) -> bool:
+ if self.tile_m <= 0 or self.tile_n <= 0 or self.tile_k <= 0:
+ return False
+ return (
+ self.tile_m % (self.warp_m * self.warp_tile_m) == 0
+ and self.tile_n % (self.warp_n * self.warp_tile_n) == 0
+ and self.tile_k % (self.warp_k * self.warp_tile_k) == 0
+ )
+
+
+@dataclass
+class TraitConfigBase:
+ """
+ Base kernel trait configuration shared by GEMM and grouped conv.
+
+ GEMM extends this with ``persistent``; grouped conv extends with
+ ``double_smem_buffer`` and ``num_groups_to_merge``.
+ """
+
+ pipeline: str # mem, compv3, compv4, compv5, ...
+ epilogue: str # cshuffle, default
+ scheduler: str # intrawave, interwave
+ pad_m: bool
+ pad_n: bool
+ pad_k: bool
+
+ # Unsupported (pipeline, epilogue, scheduler) combinations.
+ # Only 'mem' and 'basic_v1' pipelines support interwave; all compute
+ # pipelines (compv3/v4/v5/v6/async) only support intrawave.
+ _UNSUPPORTED: ClassVar[FrozenSet] = frozenset(
+ {
+ ("compv3", "cshuffle", "interwave"),
+ ("compv3", "default", "interwave"),
+ ("compv4", "cshuffle", "interwave"),
+ ("compv4", "default", "interwave"),
+ ("compv5", "cshuffle", "interwave"),
+ ("compv5", "default", "interwave"),
+ ("compv6", "cshuffle", "interwave"),
+ ("compv6", "default", "interwave"),
+ ("comp_async", "cshuffle", "interwave"),
+ ("comp_async", "default", "interwave"),
+ ("basic_async_v1", "cshuffle", "interwave"),
+ ("basic_async_v1", "default", "interwave"),
+ }
+ )
+
+ def is_valid(self) -> bool:
+ return (self.pipeline, self.epilogue, self.scheduler) not in self._UNSUPPORTED
+
+
+# ============================================================================
+# Type Mappings (centralized for both GEMM and conv codegen)
+# ============================================================================
+
+
+class CommonTypeMappings:
+ """Centralized type mappings shared by GEMM and grouped conv codegen."""
+
+ DTYPE_TO_CK = {
+ "fp16": "fp16_t",
+ "bf16": "bf16_t",
+ "fp32": "float",
+ "fp8": "fp8_t",
+ "bf8": "bf8_t",
+ "int8": "int8_t",
+ }
+
+ DTYPE_TO_CK_QUALIFIED = {
+ "fp16": "ck_tile::fp16_t",
+ "bf16": "ck_tile::bf16_t",
+ "fp32": "float",
+ "fp8": "ck_tile::fp8_t",
+ "bf8": "ck_tile::bf8_t",
+ "int8": "int8_t",
+ }
+
+ DTYPE_TO_DISPATCHER = {
+ "fp16": "DataType::FP16",
+ "bf16": "DataType::BF16",
+ "fp32": "DataType::FP32",
+ "fp8": "DataType::FP8",
+ "bf8": "DataType::BF8",
+ "int8": "DataType::INT8",
+ }
+
+ # GEMM-specific layout mappings ("r"/"c" for row/column major).
+ # Convolution layouts (NHWGC, GKYXC, etc.) are handled by
+ # unified_grouped_conv_codegen.py via GroupedConvLayout / GroupedConvTypeMappings.
+ GEMM_LAYOUT_TO_CK = {
+ "r": "tensor_layout::gemm::RowMajor",
+ "c": "tensor_layout::gemm::ColumnMajor",
+ }
+ LAYOUT_TO_CK = GEMM_LAYOUT_TO_CK # backward compat alias
+
+ GEMM_LAYOUT_TO_DISPATCHER = {
+ "r": "LayoutTag::RowMajor",
+ "c": "LayoutTag::ColMajor",
+ }
+ LAYOUT_TO_DISPATCHER = GEMM_LAYOUT_TO_DISPATCHER # backward compat alias
+
+ # GEMM-only pipeline mappings (used by unified_gemm_codegen.py).
+ # Convolution pipelines are in GroupedConvTypeMappings
+ # (unified_grouped_conv_codegen.py). CK Tile conv supports:
+ # BASIC_V1, Mem, CompV3, CompV4, CompV5, CompV6, ASYNC_V1, ASYNC_V4.
+ # The dispatcher currently generates: mem, compv3, compv4.
+ # preshufflev2 is GEMM-only (weight pre-shuffle for GEMM, not conv).
+ PIPELINE_TO_CK = {
+ "mem": "GemmPipelineAgBgCrMem",
+ "compv3": "GemmPipelineAgBgCrCompV3",
+ "compv4": "GemmPipelineAgBgCrCompV4",
+ "compv5": "GemmPipelineAgBgCrCompV5",
+ "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
+ }
+
+ PIPELINE_TO_BASE = {
+ "mem": "BaseGemmPipelineAgBgCrMem",
+ "compv3": "BaseGemmPipelineAgBgCrCompV3",
+ "compv4": "BaseGemmPipelineAgBgCrCompV4",
+ "compv5": "BaseGemmPipelineAgBgCrCompV5",
+ "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
+ }
+
+ PIPELINE_TO_DISPATCHER = {
+ "mem": "Pipeline::Mem",
+ "compv3": "Pipeline::CompV3",
+ "compv4": "Pipeline::CompV4",
+ "compv5": "Pipeline::CompV5",
+ "preshufflev2": "Pipeline::PreShuffleV2",
+ }
+
+ SCHEDULER_TO_CK = {
+ "intrawave": "GemmPipelineScheduler::Intrawave",
+ "interwave": "GemmPipelineScheduler::Interwave",
+ "default": "GemmPipelineScheduler::Default",
+ }
+
+ SCHEDULER_TO_DISPATCHER = {
+ "intrawave": "Scheduler::Intrawave",
+ "interwave": "Scheduler::Interwave",
+ "default": "Scheduler::Auto",
+ }
+
+ EPILOGUE_TO_DISPATCHER = {
+ "cshuffle": "Epilogue::CShuffle",
+ "default": "Epilogue::Default",
+ }
+
+ @staticmethod
+ def get_output_dtype(dtype: str) -> str:
+ """Get output datatype (fp8/bf8 -> fp16)."""
+ return "fp16" if dtype in ("fp8", "bf8") else dtype
+
+
+# ============================================================================
+# Code Generation Helpers
+# ============================================================================
+
+
+def generate_cpp_compilation_unit(kernel_name: str) -> str:
+ """Generate a .cpp compilation unit that includes a kernel header.
+
+ This is the standard pattern: one .cpp per kernel that just includes
+ the generated .hpp header, causing template instantiation.
+ """
+ return (
+ f"// Auto-generated compilation unit for {kernel_name}\n"
+ f'#include "{kernel_name}.hpp"\n'
+ )
+
+
+def parallel_generate(
+ generate_fn: Callable[[T], R],
+ items: Sequence[T],
+ parallel: bool = True,
+) -> List[R]:
+ """Run ``generate_fn`` over ``items``, optionally in parallel.
+
+    Logs per-item progress (pattern adopted from the conv generator).
+ Returns a flat list of results in completion order.
+ """
+ results: List[R] = []
+ if not items:
+ return results
+
+ if parallel and len(items) > 1:
+ with concurrent.futures.ThreadPoolExecutor() as executor:
+ futures = {executor.submit(generate_fn, item): item for item in items}
+ for future in concurrent.futures.as_completed(futures):
+ result = future.result()
+ results.append(result)
+ log.info("Generated: %s", futures[future])
+ else:
+ for item in items:
+ result = generate_fn(item)
+ results.append(result)
+ log.info("Generated: %s", item)
+
+ return results
+
+
+# ============================================================================
+# Arch-Aware Expansion Helpers (adopted from conv kernel_decl.hpp)
+# ============================================================================
+
+# These load from arch_specs_generated when available, falling back to
+# hardcoded defaults that match the most common arch (gfx942).
+
+_arch_data_cache: Optional[Dict] = None
+
+
+def _get_arch_data() -> Dict:
+ """Load arch filter data, with caching."""
+ global _arch_data_cache
+ if _arch_data_cache is not None:
+ return _arch_data_cache
+
+ try:
+ from arch_specs_generated import (
+ WARP_SUPPORTED_COMBINATIONS,
+ WARP_TILE_SUPPORTED_COMBINATIONS,
+ TRAIT_UNSUPPORTED_COMBINATIONS,
+ get_supported_archs,
+ )
+
+ _arch_data_cache = {
+ "warp_combos": WARP_SUPPORTED_COMBINATIONS,
+ "warp_tile_combos": WARP_TILE_SUPPORTED_COMBINATIONS,
+ "trait_unsupported": TRAIT_UNSUPPORTED_COMBINATIONS,
+ "supported_archs": get_supported_archs(),
+ }
+ except ImportError:
+ _arch_data_cache = {
+ "warp_combos": {
+ "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
+ "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
+ },
+ "warp_tile_combos": {
+ "gfx942": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]},
+ "gfx90a": {"fp16_fp16_fp32": [[16, 16, 16], [32, 32, 16]]},
+ },
+ "trait_unsupported": {
+ ("compv3", "cshuffle", "interwave"),
+ ("compv4", "cshuffle", "interwave"),
+ },
+ "supported_archs": ["gfx90a", "gfx942", "gfx950"],
+ }
+ return _arch_data_cache
+
+
+def valid_wave_configs(arch: str) -> List[List[int]]:
+ """Return valid [wave_m, wave_n, wave_k] combos for *arch*."""
+ data = _get_arch_data()
+ return data["warp_combos"].get(arch, [[2, 2, 1]])
+
+
+def valid_warp_configs(arch: str, dtype: str) -> List[List[int]]:
+ """Return valid [warp_tile_m, warp_tile_n, warp_tile_k] combos for *arch*/*dtype*.
+
+ The dtype key is constructed as ``{dtype}_{dtype}_{acc}`` where acc is
+ fp32 for float types and int32 for int8.
+ """
+ data = _get_arch_data()
+ acc = "int32" if dtype == "int8" else "fp32"
+ dtype_key = f"{dtype}_{dtype}_{acc}"
+ arch_tiles = data["warp_tile_combos"].get(arch, {})
+ return arch_tiles.get(dtype_key, [[32, 32, 16]])
+
+
+def valid_trait_configs() -> List[Tuple[str, str]]:
+ """Return valid (pipeline, scheduler) pairs.
+
+ Compute pipelines only support intrawave; mem supports both.
+ """
+ return [
+ ("compv3", "intrawave"),
+ ("compv4", "intrawave"),
+ ("compv5", "intrawave"),
+ ("mem", "intrawave"),
+ ("mem", "interwave"),
+ ]
+
+
+def needs_wave_expansion(config: dict) -> bool:
+ """True if wave_m or wave_n is a wildcard (ANY_INT = -1)."""
+ return config.get("wave_m", 2) == ANY_INT or config.get("wave_n", 2) == ANY_INT
+
+
+def needs_warp_expansion(config: dict) -> bool:
+ """True if warp_m or warp_n is a wildcard (ANY_INT = -1)."""
+ return config.get("warp_m", 32) == ANY_INT or config.get("warp_n", 32) == ANY_INT
+
+
+def needs_pipeline_expansion(config: dict) -> bool:
+ """True if pipeline is a wildcard (\"*\")."""
+ return config.get("pipeline", "compv4") == "*"
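The wildcard-expansion helpers above are meant to be wired together by a generator: when a config requests `ANY_INT`, it fans out into one concrete config per valid combo for the target arch. The sketch below reimplements that fan-out inline (using the gfx942 fallback combos from `_get_arch_data`); `expand_waves` is an invented helper, not part of `codegen_common`:

```python
ANY_INT = -1  # wildcard marker, as in codegen_common

# gfx942 fallback wave combos, mirroring _get_arch_data's defaults.
FALLBACK_WAVE_COMBOS = [[1, 4, 1], [2, 2, 1], [4, 1, 1]]

def expand_waves(config):
    """Fan a wildcard wave config out into concrete configs (illustrative)."""
    if config.get("wave_m", 2) != ANY_INT and config.get("wave_n", 2) != ANY_INT:
        return [config]  # already concrete; nothing to expand
    return [dict(config, wave_m=wm, wave_n=wn, wave_k=wk)
            for wm, wn, wk in FALLBACK_WAVE_COMBOS]
```

A concrete config passes through untouched, so expansion composes cheaply with the warp-tile and pipeline expansions.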
diff --git a/dispatcher/codegen/generate_dispatcher_registration.py b/dispatcher/codegen/generate_dispatcher_registration.py
index 024ec4a7c8..8e8b67376c 100644
--- a/dispatcher/codegen/generate_dispatcher_registration.py
+++ b/dispatcher/codegen/generate_dispatcher_registration.py
@@ -109,7 +109,7 @@ inline void register_all_kernels()
"""
output_file.write_text(content)
- print(f"✓ Generated registration header: {output_file}")
+ print(f"OK Generated registration header: {output_file}")
def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path):
@@ -143,7 +143,7 @@ namespace generated {
"""
output_file.write_text(content)
- print(f"✓ Generated registration implementation: {output_file}")
+ print(f"OK Generated registration implementation: {output_file}")
def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path):
@@ -414,8 +414,8 @@ def main():
with open(manifest_output, "w") as f:
json.dump(manifest_data, f, indent=2)
- print(f"✓ Generated manifest: {manifest_output}")
- print("\n✓ Registration code generation complete!")
+ print(f"OK Generated manifest: {manifest_output}")
+ print("\nOK Registration code generation complete!")
print(f" Total kernels: {len(kernels)}")
print(" Output files:")
print(f" - {registration_header}")
diff --git a/dispatcher/codegen/generate_kernel_wrappers.py b/dispatcher/codegen/generate_kernel_wrappers.py
index 53a9bff3ed..e11bd7a0a5 100644
--- a/dispatcher/codegen/generate_kernel_wrappers.py
+++ b/dispatcher/codegen/generate_kernel_wrappers.py
@@ -17,10 +17,10 @@ Usage:
Output structure:
build/kernel_wrappers/
- ├── gemm_fp16_rcr_128x128x32.cpp
- ├── gemm_fp16_rcr_256x256x64.cpp
- ├── conv_fwd_fp16_2d_128x128.cpp
- └── ...
+ |---- gemm_fp16_rcr_128x128x32.cpp
+ |---- gemm_fp16_rcr_256x256x64.cpp
+ |---- conv_fwd_fp16_2d_128x128.cpp
+ +---- ...
Each .cpp simply includes its corresponding .hpp and forces symbol emission.
"""
diff --git a/dispatcher/codegen/kernel_config_loader.py b/dispatcher/codegen/kernel_config_loader.py
index 537fc40581..980b4e5fd0 100644
--- a/dispatcher/codegen/kernel_config_loader.py
+++ b/dispatcher/codegen/kernel_config_loader.py
@@ -359,8 +359,8 @@ class ConvTraitConfig:
@dataclass
-class ConvKernelConfig:
- """Complete convolution kernel configuration"""
+class GroupedConvKernelConfig:
+ """Complete grouped convolution kernel configuration"""
tile: ConvTileConfig = field(default_factory=ConvTileConfig)
trait: ConvTraitConfig = field(default_factory=ConvTraitConfig)
@@ -419,7 +419,11 @@ class ConvKernelConfig:
def kernel_name(self) -> str:
"""Generate kernel name from config"""
- variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"}
+ variant_map = {
+ "forward": "fwd",
+ "bwd_data": "bwd_data",
+ "bwd_weight": "bwd_weight",
+ }
var_str = variant_map.get(self.variant, self.variant)
name = f"conv_{var_str}_{self.dtype_input}_{self.ndim}d"
@@ -433,11 +437,11 @@ class ConvKernelConfig:
@dataclass
-class ConvKernelConfigSet:
+class GroupedConvKernelConfigSet:
"""A set of convolution kernel configurations loaded from JSON"""
name: str = "default"
- configs: List[ConvKernelConfig] = field(default_factory=list)
+ configs: List[GroupedConvKernelConfig] = field(default_factory=list)
# Tile parameter ranges
tile_m_values: List[int] = field(default_factory=lambda: [128])
@@ -481,7 +485,7 @@ class ConvKernelConfigSet:
layout: str = "nhwgc"
gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"])
- def generate_configs(self) -> Iterator[ConvKernelConfig]:
+ def generate_configs(self) -> Iterator[GroupedConvKernelConfig]:
"""Generate all kernel configurations (cartesian product)"""
# Tile parameters
tile_params = itertools.product(
@@ -548,7 +552,7 @@ class ConvKernelConfigSet:
double_smem_buffer=trait[6],
num_groups_to_merge=trait[7],
)
- yield ConvKernelConfig(
+ yield GroupedConvKernelConfig(
tile=tile_cfg,
trait=trait_cfg,
dtype_input=self.dtype_input,
@@ -599,7 +603,9 @@ class ConvKernelConfigSet:
return tile_count * trait_count * extra_count * len(self.gpu_targets)
-def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
+def load_grouped_conv_kernel_configs(
+ json_path: str | Path,
+) -> GroupedConvKernelConfigSet:
"""
Load convolution kernel configurations from a JSON file.
@@ -607,14 +613,14 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
json_path: Path to JSON configuration file
Returns:
- ConvKernelConfigSet with all parameter values loaded
+ GroupedConvKernelConfigSet with all parameter values loaded
"""
json_path = Path(json_path)
with open(json_path) as f:
data = json.load(f)
- config_set = ConvKernelConfigSet()
+ config_set = GroupedConvKernelConfigSet()
# Name
config_set.name = data.get("kernel_set_name", json_path.stem)
@@ -680,15 +686,15 @@ def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
def generate_cpp_conv_kernel_set_declaration(
- config_set: ConvKernelConfigSet,
+ config_set: GroupedConvKernelConfigSet,
set_name: Optional[str] = None,
) -> str:
"""
- Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet.
+ Generate C++ DECL_GROUPED_CONV_KERNEL_SET code from a GroupedConvKernelConfigSet.
"""
name = set_name or config_set.name
- lines = [f"DECL_CONV_KERNEL_SET({name},"]
+ lines = [f"DECL_GROUPED_CONV_KERNEL_SET({name},"]
for config in config_set.generate_configs():
line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, '
diff --git a/dispatcher/codegen/unified_gemm_codegen.py b/dispatcher/codegen/unified_gemm_codegen.py
index b0dd961be7..a818cec83e 100755
--- a/dispatcher/codegen/unified_gemm_codegen.py
+++ b/dispatcher/codegen/unified_gemm_codegen.py
@@ -7,7 +7,7 @@
Unified GEMM Code Generator - Single Source of Truth
This is THE unified code generator for all GEMM kernel variants:
-- Standard GEMM (C = A × B)
+- Standard GEMM (C = A x B)
- Preshuffle GEMM (optimized weight access)
- Multi-D GEMM (element-wise fusion)
@@ -25,6 +25,12 @@ from dataclasses import dataclass, asdict
from enum import Enum
import concurrent.futures
+from codegen_common import (
+ TileConfig,
+ TraitConfigBase,
+ CommonTypeMappings as TypeMappings,
+)
+
# Import architecture filter for GPU-specific validation
try:
from arch_filter import ArchFilter, KernelConfig as ArchKernelConfig, OperatorType
@@ -194,62 +200,14 @@ class GemmVariant(Enum):
MULTI_D = "multi_d"
-@dataclass
-class TileConfig:
- """Tile configuration parameters"""
-
- tile_m: int
- tile_n: int
- tile_k: int
- warp_m: int
- warp_n: int
- warp_k: int
- warp_tile_m: int
- warp_tile_n: int
- warp_tile_k: int
-
- def is_valid(self) -> bool:
- """Validate tile configuration"""
- return (
- self.tile_m % (self.warp_m * self.warp_tile_m) == 0
- and self.tile_n % (self.warp_n * self.warp_tile_n) == 0
- and self.tile_k % (self.warp_k * self.warp_tile_k) == 0
- and self.tile_m > 0
- and self.tile_n > 0
- and self.tile_k > 0
- )
+# TileConfig imported from codegen_common
@dataclass
-class TraitConfig:
- """Kernel trait configuration"""
+class TraitConfig(TraitConfigBase):
+ """GEMM-specific trait configuration extending TraitConfigBase with persistent mode."""
- pipeline: str # mem, compv3, compv4
- epilogue: str # default, cshuffle
- scheduler: str # intrawave, interwave
- pad_m: bool
- pad_n: bool
- pad_k: bool
- persistent: bool
-
- def is_valid(self) -> bool:
- """Check if trait combination is valid"""
- # Unsupported combinations
- # Only 'mem' pipeline supports interwave scheduler.
- # All compute pipelines (compv3/v4/v5/v6/async) only support intrawave.
- unsupported = {
- ("compv3", "cshuffle", "interwave"),
- ("compv3", "default", "interwave"),
- ("compv4", "cshuffle", "interwave"),
- ("compv4", "default", "interwave"),
- ("compv5", "cshuffle", "interwave"),
- ("compv5", "default", "interwave"),
- ("compv6", "cshuffle", "interwave"),
- ("compv6", "default", "interwave"),
- ("comp_async", "cshuffle", "interwave"),
- ("comp_async", "default", "interwave"),
- }
- return (self.pipeline, self.epilogue, self.scheduler) not in unsupported
+ persistent: bool = False
@dataclass
@@ -345,89 +303,7 @@ class KernelConfig:
# ============================================================================
-class TypeMappings:
- """Centralized type mappings for code generation"""
-
- DTYPE_TO_CK = {
- "fp16": "fp16_t",
- "bf16": "bf16_t",
- "fp32": "float",
- "fp8": "fp8_t",
- "bf8": "bf8_t",
- "int8": "int8_t",
- }
-
- # Fully-qualified types for use outside of 'using namespace ck_tile' scope
- DTYPE_TO_CK_QUALIFIED = {
- "fp16": "ck_tile::fp16_t",
- "bf16": "ck_tile::bf16_t",
- "fp32": "float", # Built-in type, no namespace
- "fp8": "ck_tile::fp8_t",
- "bf8": "ck_tile::bf8_t",
- "int8": "int8_t", # Built-in type
- }
-
- DTYPE_TO_DISPATCHER = {
- "fp16": "DataType::FP16",
- "bf16": "DataType::BF16",
- "fp32": "DataType::FP32",
- "fp8": "DataType::FP8",
- "bf8": "DataType::BF8",
- "int8": "DataType::INT8",
- }
-
- LAYOUT_TO_CK = {
- "r": "tensor_layout::gemm::RowMajor",
- "c": "tensor_layout::gemm::ColumnMajor",
- }
-
- LAYOUT_TO_DISPATCHER = {
- "r": "LayoutTag::RowMajor",
- "c": "LayoutTag::ColMajor",
- }
-
- PIPELINE_TO_CK = {
- "mem": "GemmPipelineAgBgCrMem",
- "compv3": "GemmPipelineAgBgCrCompV3",
- "compv4": "GemmPipelineAgBgCrCompV4",
- "preshufflev2": "WeightPreshufflePipelineAGmemBGmemCRegV2",
- }
-
- PIPELINE_TO_BASE = {
- "mem": "BaseGemmPipelineAgBgCrMem",
- "compv3": "BaseGemmPipelineAgBgCrCompV3",
- "compv4": "BaseGemmPipelineAgBgCrCompV4",
- "preshufflev2": "BaseWeightPreshufflePipelineAGmemBGmemCRegV2",
- }
-
- PIPELINE_TO_DISPATCHER = {
- "mem": "Pipeline::Mem",
- "compv3": "Pipeline::CompV3",
- "compv4": "Pipeline::CompV4",
- "preshufflev2": "Pipeline::PreShuffleV2",
- }
-
- SCHEDULER_TO_CK = {
- "intrawave": "GemmPipelineScheduler::Intrawave",
- "interwave": "GemmPipelineScheduler::Interwave",
- "default": "GemmPipelineScheduler::Default",
- }
-
- SCHEDULER_TO_DISPATCHER = {
- "intrawave": "Scheduler::Intrawave",
- "interwave": "Scheduler::Interwave",
- "default": "Scheduler::Auto",
- }
-
- EPILOGUE_TO_DISPATCHER = {
- "cshuffle": "Epilogue::CShuffle",
- "default": "Epilogue::Default",
- }
-
- @staticmethod
- def get_output_dtype(dtype: str) -> str:
- """Get output datatype (fp8/bf8 -> fp16)"""
- return "fp16" if dtype in ["fp8", "bf8"] else dtype
+# TypeMappings imported from codegen_common as CommonTypeMappings -> TypeMappings alias
# ============================================================================
@@ -1068,7 +944,11 @@ class UnifiedGemmCodegen:
}
def generate_all(self, parallel: bool = True) -> Dict:
- """Generate all kernels"""
+ """Generate all kernels.
+
+ When parallel=True, all configs across all variants are collected first,
+ then generated concurrently in a single thread pool for maximum throughput.
+ """
log.info("Generating GEMM kernels:")
log.info(f" Datatype: {self.datatype}")
log.info(f" Layout: {self.layout}")
@@ -1078,49 +958,24 @@ class UnifiedGemmCodegen:
results = {"kernels": [], "wrappers": [], "failed": []}
- # Get configurations
+ # Collect ALL configs across all variants/preselected sets upfront
+ all_configs = []
if self.use_preselected:
- configs = self._get_preselected_configs()
- log.info(f" Total configurations: {len(configs)}")
+ all_configs = self._get_preselected_configs()
+ log.info(f" Total configurations: {len(all_configs)}")
else:
for variant in self.variants:
- log.info(f"\nGenerating {variant.value} kernels...")
configs = self._get_configs_for_variant(variant)
- log.info(f" Configurations: {len(configs)}")
+ log.info(f" {variant.value}: {len(configs)} configurations")
+ all_configs.extend(configs)
+ log.info(f" Total across all variants: {len(all_configs)}")
- if parallel:
- with concurrent.futures.ThreadPoolExecutor() as executor:
- futures = [
- executor.submit(self._generate_one, cfg) for cfg in configs
- ]
- for future in concurrent.futures.as_completed(futures):
- try:
- k, w = future.result()
- results["kernels"].append(k)
- results["wrappers"].append(w)
- except Exception as e:
- results["failed"].append(str(e))
- log.error(f"Failed: {e}")
- else:
- for cfg in configs:
- try:
- k, w = self._generate_one(cfg)
- results["kernels"].append(k)
- results["wrappers"].append(w)
- except Exception as e:
- results["failed"].append(str(e))
- log.error(f"Failed: {e}")
-
- # Generate registration header
- if results["wrappers"]:
- self._generate_registration_header(results["wrappers"])
-
- return results
-
- # Generate from preselected set
- if parallel:
+ # Generate all configs in a single parallel pass
+ if parallel and all_configs:
with concurrent.futures.ThreadPoolExecutor() as executor:
- futures = [executor.submit(self._generate_one, cfg) for cfg in configs]
+ futures = [
+ executor.submit(self._generate_one, cfg) for cfg in all_configs
+ ]
for future in concurrent.futures.as_completed(futures):
try:
k, w = future.result()
@@ -1130,7 +985,7 @@ class UnifiedGemmCodegen:
results["failed"].append(str(e))
log.error(f"Failed: {e}")
else:
- for cfg in configs:
+ for cfg in all_configs:
try:
k, w = self._generate_one(cfg)
results["kernels"].append(k)
@@ -1139,7 +994,6 @@ class UnifiedGemmCodegen:
results["failed"].append(str(e))
log.error(f"Failed: {e}")
- # Generate registration header
if results["wrappers"]:
self._generate_registration_header(results["wrappers"])
@@ -1638,12 +1492,19 @@ def main():
# Write to temp file and use as config
import tempfile
+ import os as _os
- with tempfile.NamedTemporaryFile(
+ _tmp_config = tempfile.NamedTemporaryFile(
mode="w", suffix=".json", delete=False
- ) as f:
- json.dump(full_config, f)
- args.config = Path(f.name)
+ )
+ try:
+ json.dump(full_config, _tmp_config)
+ _tmp_config.close()
+ args.config = Path(_tmp_config.name)
+ except Exception:
+ _tmp_config.close()
+ _os.unlink(_tmp_config.name)
+ raise
except json.JSONDecodeError as e:
logging.error(f"Invalid tile-config-json: {e}")
return 1
@@ -1672,7 +1533,7 @@ def main():
results = codegen.generate_all(parallel=not args.no_parallel)
- logging.info("\n✅ Generation complete!")
+ logging.info("\nGeneration complete.")
logging.info(f" Kernels: {len(results['kernels'])}")
logging.info(f" Wrappers: {len(results['wrappers'])}")
logging.info(f" Failed: {len(results['failed'])}")
@@ -1684,7 +1545,7 @@ def main():
# Generate dispatcher registration if requested
if args.register:
- logging.info("\n📝 Generating dispatcher registration code...")
+ logging.info("\nGenerating dispatcher registration code...")
try:
from generate_dispatcher_registration import (
scan_generated_headers,
@@ -1701,11 +1562,20 @@ def main():
)
generate_registration_cpp(kernels, reg_dir / "dispatcher_registration.cpp")
- logging.info(f"✓ Generated registration code for {len(kernels)} kernels")
+ logging.info(f"Generated registration code for {len(kernels)} kernels")
except Exception as e:
logging.error(f"Failed to generate registration code: {e}")
return 1
+ # Clean up temp config file if we created one
+ if args.tile_config_json and args.config and args.config.exists():
+ try:
+ import os as _os
+
+ _os.unlink(args.config)
+ except OSError:
+ pass
+
return 0 if not results["failed"] else 1
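The temp-config rework in `main()` above replaces the `with NamedTemporaryFile(...)` block by an explicit close-then-reuse sequence with unlink-on-failure, plus a final cleanup pass. A minimal standalone sketch of that pattern — the `write_temp_config` helper name is ours, not from the patch:

```python
import json
import os
import tempfile
from pathlib import Path

def write_temp_config(config: dict) -> Path:
    """Persist `config` to a temp JSON file that outlives this call."""
    # delete=False keeps the file after close() so a later stage can
    # re-open it by path; the caller unlinks it once generation is done.
    tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False)
    try:
        json.dump(config, tmp)
        tmp.close()  # flush and release the handle before handing out the path
        return Path(tmp.name)
    except Exception:
        tmp.close()
        os.unlink(tmp.name)  # don't leak the temp file on failure
        raise

cfg_path = write_temp_config({"tile_m": 128, "tile_n": 128})
print(json.loads(cfg_path.read_text()))
os.unlink(cfg_path)  # mirror the cleanup step at the end of main()
```

Closing the handle before returning the path matters because on Windows a still-open `NamedTemporaryFile` cannot be reopened by name, which is why the patch moves away from the `with` block that kept the file open while `args.config` was consumed.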
diff --git a/dispatcher/codegen/unified_grouped_conv_codegen.py b/dispatcher/codegen/unified_grouped_conv_codegen.py
new file mode 100644
index 0000000000..ff40cb4ed4
--- /dev/null
+++ b/dispatcher/codegen/unified_grouped_conv_codegen.py
@@ -0,0 +1,1757 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+# SPDX-License-Identifier: MIT
+
+"""
+Unified Grouped Convolution Code Generator
+
+This is the unified code generator for all grouped convolution kernel variants:
+- Forward grouped convolution
+- Backward data grouped convolution
+- Backward weight grouped convolution
+
+Generates both CK Tile kernels AND dispatcher wrappers.
+Based on the GEMM codegen pattern.
+"""
+
+import argparse
+import logging
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+from dataclasses import dataclass
+from enum import Enum
+
+from codegen_common import (
+ TileConfig,
+ TraitConfigBase,
+ parallel_generate,
+)
+
+logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+log = logging.getLogger(__name__)
+
+# Import architecture filter for GPU-specific validation
+try:
+ from arch_filter import ArchFilter, OperatorType
+
+ HAS_ARCH_FILTER = True
+except ImportError:
+ HAS_ARCH_FILTER = False
+ ArchFilter = None
+ OperatorType = None
+
+
+# ============================================================================
+# Configuration and Data Structures
+# ============================================================================
+
+
+class GroupedConvVariant(Enum):
+ """Grouped convolution kernel variants"""
+
+ FORWARD = "forward"
+ BACKWARD_DATA = "bwd_data"
+ BACKWARD_WEIGHT = "bwd_weight"
+
+
+class GroupedConvLayout(Enum):
+ """Grouped convolution data layouts"""
+
+ # 1D
+ NWGC = "NWGC" # Input/Output: N W G C
+ GKXC = "GKXC" # Weight: G K X C
+ NWGK = "NWGK" # Output: N W G K
+
+ # 2D
+ NHWGC = "NHWGC" # Input: N H W G C
+ GKYXC = "GKYXC" # Weight: G K Y X C
+ NHWGK = "NHWGK" # Output: N H W G K
+
+ # 3D
+ NDHWGC = "NDHWGC" # Input: N D H W G C
+ GKZYXC = "GKZYXC" # Weight: G K Z Y X C
+ NDHWGK = "NDHWGK" # Output: N D H W G K
+
+
+@dataclass
+class GroupedConvTraitConfig(TraitConfigBase):
+ """Kernel trait configuration for grouped convolution (extends TraitConfigBase).
+
+ Conv-specific extensions beyond TraitConfigBase. These map to
+ GroupedConvTraits template parameters in grouped_convolution_utils.hpp:
+ - double_smem_buffer: ping-pong LDS for compute V4+ pipelines
+ - num_groups_to_merge: fuse multiple groups into one tile (NumGroupsToMerge)
+ - split_image: split spatial dims for large tensors (EnableSplitImage)
+ - explicit_gemm: use explicit GEMM path (ExplicitGemm)
+ - two_stage: two-stage bwd_weight with fp32 workspace + elementwise convert
+
+ Note: CK Tile already uses long_index_t (64-bit) for group strides and
+ batch offsets, so there is no separate "large_tensor" flag. For large
+ spatial dimensions, use split_image=True instead.
+ """
+
+ double_smem_buffer: bool = False
+ num_groups_to_merge: int = 1
+ split_image: bool = False
+ explicit_gemm: bool = False
+ two_stage: bool = False
+
+
+# Backward compatibility alias
+TraitConfig = GroupedConvTraitConfig
+
+
+@dataclass
+class GroupedConvKernelConfig:
+ """Complete grouped convolution kernel configuration"""
+
+ tile: TileConfig
+ trait: GroupedConvTraitConfig
+ variant: GroupedConvVariant = GroupedConvVariant.FORWARD
+ ndim_spatial: int = 2 # 1D, 2D, or 3D
+ arch: str = "gfx942" # Target architecture
+ layout: Union[str, GroupedConvLayout] = (
+ "nhwgc" # Data layout (e.g., "nhwgc", "ndhwgc")
+ )
+
+ # Vector sizes: a=4 for fp16 input (8-byte aligned global loads),
+ # b=8 for weight tensor, c=8 for output stores. These match the
+ # CK Tile default vectorization widths for fp16 on CDNA3 (gfx942).
+ vector_size_a: int = 4
+ vector_size_b: int = 8
+ vector_size_c: int = 8
+ vector_sizes: Optional[Tuple[int, int, int]] = None
+
+ # Occupancy parameters
+ block_per_cu: int = 1
+ num_wave_groups: int = 1
+ num_groups_to_merge: int = 1
+
+ # Double buffering
+ double_smem_buffer: bool = False
+
+ def __post_init__(self):
+ if self.vector_sizes is not None:
+ self.vector_size_a, self.vector_size_b, self.vector_size_c = (
+ self.vector_sizes[:3]
+ )
+        # Sync trait fields with top-level fields: an explicitly set top-level
+        # value fills a default trait value; otherwise the trait value wins.
+ if self.double_smem_buffer and not self.trait.double_smem_buffer:
+ self.trait.double_smem_buffer = self.double_smem_buffer
+ elif self.trait.double_smem_buffer:
+ self.double_smem_buffer = self.trait.double_smem_buffer
+ if self.num_groups_to_merge != 1 and self.trait.num_groups_to_merge == 1:
+ self.trait.num_groups_to_merge = self.num_groups_to_merge
+ elif self.trait.num_groups_to_merge != 1:
+ self.num_groups_to_merge = self.trait.num_groups_to_merge
+
+ def _layout_str(self) -> str:
+ """Get layout as lowercase string for naming."""
+ if hasattr(self.layout, "value"):
+ return self.layout.value.lower()
+ return str(self.layout).lower()
+
+ def name(self, datatype: str) -> str:
+ """
+ Generate kernel name that uniquely identifies the kernel configuration.
+
+ Format: grouped_conv_{variant}_{dtype}_{layout}_{ndim}d_{pipeline}_{epilogue}_{scheduler}
+ _{tile_m}x{tile_n}x{tile_k}_{warp_m}x{warp_n}x{warp_k}
+ _{warp_tile_m}x{warp_tile_n}x{warp_tile_k}
+ [_vec{a}_{b}_{c}][_bpc{n}][_wg{n}][_gm{n}][_dsb][_pad{mnk}]
+
+ All parameters that affect kernel behavior MUST be included to ensure
+ unique names for unique configurations:
+ - Variant (fwd/bwd_data/bwd_weight)
+ - Data type
+ - Layout (nhwgc, nchw, ndhwgc, etc.)
+ - Spatial dimensions (2d/3d)
+ - Pipeline, epilogue, scheduler
+ - Tile, warp, warp_tile dimensions
+ - Vector sizes, occupancy hints (if non-default)
+ - Double SMEM buffer, padding flags
+ """
+ t = self.tile
+ tr = self.trait
+ layout_str = self._layout_str()
+
+ variant_str = {
+ GroupedConvVariant.FORWARD: "fwd",
+ GroupedConvVariant.BACKWARD_DATA: "bwd_data",
+ GroupedConvVariant.BACKWARD_WEIGHT: "bwd_weight",
+ }[self.variant]
+
+ # Core identity: variant, dtype, layout, dims
+ name = (
+ f"grouped_conv_{variant_str}_{datatype}_{layout_str}_{self.ndim_spatial}d"
+ )
+
+ # Pipeline configuration
+ name += f"_{tr.pipeline}_{tr.epilogue}_{tr.scheduler}"
+
+ # Block tile dimensions (M_Tile x N_Tile x K_Tile)
+ name += f"_{t.tile_m}x{t.tile_n}x{t.tile_k}"
+
+ # Wave distribution (M_Warp x N_Warp x K_Warp)
+ name += f"_{t.warp_m}x{t.warp_n}x{t.warp_k}"
+
+ # Warp tile dimensions (M_Warp_Tile x N_Warp_Tile x K_Warp_Tile)
+ name += f"_{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}"
+
+ # Vector sizes (only if non-default)
+ if (self.vector_size_a, self.vector_size_b, self.vector_size_c) != (4, 8, 8):
+ name += (
+ f"_vec{self.vector_size_a}_{self.vector_size_b}_{self.vector_size_c}"
+ )
+
+ # Occupancy hints (only if non-default)
+ if self.block_per_cu != 1:
+ name += f"_bpc{self.block_per_cu}"
+
+ if self.num_wave_groups != 1:
+ name += f"_wg{self.num_wave_groups}"
+
+ if self.num_groups_to_merge != 1:
+ name += f"_gm{self.num_groups_to_merge}"
+
+ # Double SMEM buffer (for compute V4+)
+ if self.double_smem_buffer or tr.double_smem_buffer:
+ name += "_dsb"
+
+ # Two-stage bwd_weight (fp32 workspace + elementwise convert)
+ if tr.two_stage:
+ name += "_2stage"
+
+ # Padding suffix (only if not all enabled)
+ if not (tr.pad_m and tr.pad_n and tr.pad_k):
+ name += f"_pad{int(tr.pad_m)}{int(tr.pad_n)}{int(tr.pad_k)}"
+
+ return name
+
+ def is_valid_for_arch(self, arch: Optional[str] = None) -> bool:
+ """Check if configuration is valid for target architecture"""
+ target_arch = arch if arch is not None else self.arch
+
+ # Check trait validity
+ if not self.trait.is_valid():
+ return False
+
+ # Backward operations have stricter pipeline requirements:
+ # - Backward weight: compv4/compv5 have transpose_tile2d issues
+ # - Backward data: compv4 has get_length issues in bwd_data kernel
+ # Both backward operations ONLY support compv3 and mem pipelines
+ if self.variant in (
+ GroupedConvVariant.BACKWARD_WEIGHT,
+ GroupedConvVariant.BACKWARD_DATA,
+ ):
+ if self.trait.pipeline not in ("compv3", "mem"):
+ return False
+
+ # Check warp configuration (from arch_specs)
+ try:
+ from arch_specs_generated import WARP_SUPPORTED_COMBINATIONS
+
+ supported = WARP_SUPPORTED_COMBINATIONS.get(target_arch)
+ if supported is None:
+ return False # Unknown architecture
+ warp_cfg = [self.tile.warp_m, self.tile.warp_n, self.tile.warp_k]
+ if warp_cfg not in supported:
+ return False
+ except ImportError:
+ pass # Allow if arch_specs not available
+
+ return True
+
+
+# ============================================================================
+# Type Mappings
+# ============================================================================
+
+
+class GroupedConvTypeMappings:
+ """Centralized type mappings for grouped convolution code generation"""
+
+ DTYPE_TO_CK = {
+ "fp16": "half_t",
+ "bf16": "bf16_t",
+ "fp32": "float",
+ }
+
+ # CK Tile conv pipelines (from conv_configs.hpp PipelineTypeTraits).
+ # basic_v1/mem/compv3 use GroupedConvUniversalPipelineAgBgCrPolicy;
+ # compv4/compv5/compv6/comp_async/basic_async_v1 use their own default policy.
+ PIPELINE_TO_CK = {
+ "basic_v1": "GemmPipeline::BASIC_V1",
+ "mem": "GemmPipeline::MEMORY",
+ "compv3": "GemmPipeline::COMPUTE_V3",
+ "compv4": "GemmPipeline::COMPUTE_V4",
+ "compv5": "GemmPipeline::COMPUTE_V5",
+ "compv6": "GemmPipeline::COMPUTE_V6",
+ "comp_async": "GemmPipeline::COMPUTE_ASYNC",
+ "basic_async_v1": "GemmPipeline::BASIC_ASYNC_V1",
+ }
+
+ SCHEDULER_TO_CK = {
+ "intrawave": "GemmPipelineScheduler::Intrawave",
+ "interwave": "GemmPipelineScheduler::Interwave",
+ }
+
+ LAYOUT_1D = {
+ "in": "tensor_layout::convolution::NWGC",
+ "wei": "tensor_layout::convolution::GKXC",
+ "out": "tensor_layout::convolution::NWGK",
+ }
+
+ LAYOUT_2D = {
+ "in": "tensor_layout::convolution::NHWGC",
+ "wei": "tensor_layout::convolution::GKYXC",
+ "out": "tensor_layout::convolution::NHWGK",
+ }
+
+ LAYOUT_3D = {
+ "in": "tensor_layout::convolution::NDHWGC",
+ "wei": "tensor_layout::convolution::GKZYXC",
+ "out": "tensor_layout::convolution::NDHWGK",
+ }
+
+ @classmethod
+ def get_layouts(cls, ndim: int) -> dict:
+ if ndim == 1:
+ return cls.LAYOUT_1D
+ elif ndim == 2:
+ return cls.LAYOUT_2D
+ else:
+ return cls.LAYOUT_3D
+
+
+# ============================================================================
+# CK Tile Grouped Conv Kernel Generator
+# ============================================================================
+
+
+class CKTileGroupedConvKernelGenerator:
+ """Generates CK Tile grouped convolution kernel instance code"""
+
+ def __init__(
+ self,
+ datatype: str,
+ variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
+ ):
+ self.datatype = datatype
+ self.variant = variant
+ self.tm = GroupedConvTypeMappings()
+
+ def generate(self, config: GroupedConvKernelConfig) -> str:
+ """Generate complete CK Tile grouped convolution kernel"""
+ kernel_name = config.name(self.datatype)
+ return f"""{self._header(kernel_name, config)}
+{self._config_struct(config, kernel_name)}
+{self._kernel_instance(config, kernel_name)}
+"""
+
+ def _header(self, kernel_name: str, config: GroupedConvKernelConfig) -> str:
+ """Generate header includes based on variant"""
+ if self.variant == GroupedConvVariant.BACKWARD_DATA:
+ kernel_header = "grouped_convolution_backward_data_kernel.hpp"
+ elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT:
+ kernel_header = "grouped_convolution_backward_weight_kernel.hpp"
+ else:
+ kernel_header = "grouped_convolution_forward_kernel.hpp"
+
+ elementwise_include = ""
+ if config.trait.two_stage:
+ elementwise_include = '\n#include "ck_tile/ops/elementwise.hpp"'
+
+ return f"""// SPDX-License-Identifier: MIT
+// Auto-generated CK Tile Grouped Convolution kernel: {kernel_name}
+// Variant: {self.variant.value}
+#pragma once
+
+#include <functional>
+#include <numeric>
+#include <stdexcept>
+#include "ck_tile/core.hpp"
+#include "ck_tile/host/kernel_launch.hpp"
+#include "ck_tile/ops/gemm.hpp"
+#include "ck_tile/ops/grouped_convolution.hpp"
+#include "ck_tile/ops/epilogue.hpp"
+#include "ck_tile/ops/grouped_convolution/kernel/{kernel_header}"
+#include "ck_tile/ops/grouped_convolution/pipeline/grouped_conv_universal_pipeline_ag_bg_cr_policy.hpp"{elementwise_include}
+
+using namespace ck_tile;
+"""
+
+ def _config_struct(self, config: GroupedConvKernelConfig, kernel_name: str) -> str:
+ """Generate config struct"""
+ t = config.tile
+ tr = config.trait
+ layouts = self.tm.get_layouts(config.ndim_spatial)
+
+ return f"""
+// Kernel configuration
+struct {kernel_name}_Config {{
+ // Data types
+ using InDataType = {self.tm.DTYPE_TO_CK[self.datatype]};
+ using WeiDataType = {self.tm.DTYPE_TO_CK[self.datatype]};
+ using AccDataType = float;
+ using OutDataType = {self.tm.DTYPE_TO_CK[self.datatype]};
+
+ // Layouts
+ using InLayout = {layouts["in"]};
+ using WeiLayout = {layouts["wei"]};
+ using OutLayout = {layouts["out"]};
+
+ // Tile shape
+ static constexpr index_t M_Tile = {t.tile_m};
+ static constexpr index_t N_Tile = {t.tile_n};
+ static constexpr index_t K_Tile = {t.tile_k};
+
+ static constexpr index_t M_Warp = {t.warp_m};
+ static constexpr index_t N_Warp = {t.warp_n};
+ static constexpr index_t K_Warp = {t.warp_k};
+
+ static constexpr index_t M_Warp_Tile = {t.warp_tile_m};
+ static constexpr index_t N_Warp_Tile = {t.warp_tile_n};
+ static constexpr index_t K_Warp_Tile = {t.warp_tile_k};
+
+ // Vector sizes
+ static constexpr index_t VectorSizeA = {config.vector_size_a};
+ static constexpr index_t VectorSizeB = {config.vector_size_b};
+ static constexpr index_t VectorSizeC = {config.vector_size_c};
+
+ // Padding
+ static constexpr bool kPadM = {str(tr.pad_m).lower()};
+ static constexpr bool kPadN = {str(tr.pad_n).lower()};
+ static constexpr bool kPadK = {str(tr.pad_k).lower()};
+
+ // Pipeline & Epilogue
+ static constexpr auto Pipeline = {self.tm.PIPELINE_TO_CK[tr.pipeline]};
+ static constexpr auto Scheduler = {self.tm.SCHEDULER_TO_CK[tr.scheduler]};
+ static constexpr bool DoubleSmemBuffer = {str(tr.double_smem_buffer).lower()};
+ static constexpr bool UseCShuffleEpilogue = {str(tr.epilogue == "cshuffle").lower()};
+
+ // Other params
+ static constexpr int kBlockPerCu = {config.block_per_cu};
+ static constexpr index_t NumWaveGroups = {config.num_wave_groups};
+ static constexpr index_t NumGroupsToMerge = {tr.num_groups_to_merge};
+ static constexpr bool EnableSplitImage = {str(tr.split_image).lower()};
+ static constexpr bool ExplicitGemm = {str(tr.explicit_gemm).lower()};
+ static constexpr index_t NDimSpatial = {config.ndim_spatial};
+
+ // Target architecture
+ static constexpr const char* TargetArch = "{config.arch}";
+}};
+"""
+
+ def _kernel_instance(
+ self, config: GroupedConvKernelConfig, kernel_name: str
+ ) -> str:
+ """Generate kernel instantiation code with launch function"""
+ tr = config.trait
+
+ if self.variant == GroupedConvVariant.BACKWARD_WEIGHT and tr.two_stage:
+ return self._kernel_instance_two_stage(config, kernel_name)
+
+ # Variant-specific configuration
+ if self.variant == GroupedConvVariant.BACKWARD_DATA:
+ host_args_type = "GroupedConvBwdDataHostArgs"
+ kernel_type = "GroupedConvolutionBackwardDataKernel"
+ gemm_traits = "GroupedConvImplicitGemmTraitsBwdData"
+ layout_suffix = "BwdData"
+ # For bwd_data: A=dOutput, B=Weight, C=dInput
+ a_dtype = "OutDataType"
+ b_dtype = "WeiDataType"
+ c_dtype = "InDataType"
+ gemm_k_calc = "args.K_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()"
+ direction_prefix = "BWD_DATA"
+ launcher_alias = "SelectedConvBwdDataLauncher"
+ elif self.variant == GroupedConvVariant.BACKWARD_WEIGHT:
+ host_args_type = "GroupedConvBwdWeightHostArgs"
+ kernel_type = "GroupedConvolutionBackwardWeightKernel"
+ gemm_traits = "GroupedConvImplicitGemmTraitsBwdWeight"
+ layout_suffix = "BwdWeight"
+ # For bwd_weight: A=dOutput, B=Input, C=dWeight (per CK Tile invoker)
+ a_dtype = "OutDataType"
+ b_dtype = "InDataType"
+ c_dtype = "WeiDataType"
+ gemm_k_calc = "args.N_ * std::accumulate(args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end()"
+ direction_prefix = "BWD_WEIGHT"
+ launcher_alias = "SelectedConvBwdWeightLauncher"
+ else: # Forward
+ host_args_type = "GroupedConvFwdHostArgs<>"
+ kernel_type = "GroupedConvolutionForwardKernel"
+ gemm_traits = "GroupedConvImplicitGemmTraitsFwd"
+ layout_suffix = "Fwd"
+ a_dtype = "InDataType"
+ b_dtype = "WeiDataType"
+ c_dtype = "OutDataType"
+ gemm_k_calc = "args.C_ * std::accumulate(args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end()"
+ direction_prefix = "FWD"
+ launcher_alias = "SelectedConvKernelLauncher"
+
+ # Create valid C++ namespace name
+ ns_name = "ns_" + kernel_name.replace("-", "_")
+
+ return f"""
+// Unique namespace for this kernel to avoid conflicts when including multiple kernels
+namespace {ns_name} {{
+
+// Bring Config into namespace
+using Config = {kernel_name}_Config;
+
+// Kernel name for identification
+constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}";
+
+// Selected kernel alias
+using SelectedConv{direction_prefix.title()}Kernel = Config;
+
+// =============================================================================
+// Kernel Launch Implementation ({self.variant.value})
+// =============================================================================
+
+struct {kernel_name}_Launcher {{
+ using KernelConfig = Config; // Use the Config alias from namespace
+ using InDataType = typename Config::InDataType;
+ using WeiDataType = typename Config::WeiDataType;
+ using OutDataType = typename Config::OutDataType;
+ using AccDataType = typename Config::AccDataType;
+ using InLayout = typename Config::InLayout;
+ using WeiLayout = typename Config::WeiLayout;
+ using OutLayout = typename Config::OutLayout;
+
+ static constexpr index_t NDimSpatial = Config::NDimSpatial;
+
+ // Implicit GEMM shape
+ using GemmShape = TileGemmShape<
+        sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
+        sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
+        sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;
+
+ // Convolution traits
+ static constexpr auto ConvSpec = ConvolutionSpecialization::Default;
+ using GroupedConvTraitsType = GroupedConvTraits<
+ NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout,
+ Config::VectorSizeA, Config::VectorSizeB, Config::VectorSizeC,
+ Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>;
+
+ // Tile partitioner
+ using TilePartitioner = GemmSpatiallyLocalTilePartitioner<
+ GemmShape,
+ GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
+ GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
+
+ // Universal traits - layout suffix changes per variant
+ using GemmUniversalTraits = TileGemmUniversalTraits<
+ GroupedConvTraitsType::FixedGemmParams::kPadM,
+ GroupedConvTraitsType::FixedGemmParams::kPadN,
+ GroupedConvTraitsType::FixedGemmParams::kPadK,
+ Config::DoubleSmemBuffer,
+ typename GroupedConvTraitsType::AsLayout{layout_suffix},
+ typename GroupedConvTraitsType::BsLayout{layout_suffix},
+ typename GroupedConvTraitsType::CLayout{layout_suffix},
+ GroupedConvTraitsType::FixedGemmParams::TransposeC,
+ GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
+ GroupedConvTraitsType::FixedGemmParams::Persistent,
+ Config::NumWaveGroups>;
+
+ // Pipeline problem - data types change per variant
+ using GemmPipelineProblem = GemmPipelineProblem<
+ {a_dtype}, {b_dtype}, AccDataType, GemmShape,
+ typename GroupedConvTraitsType::template {gemm_traits},
+ element_wise::PassThrough, element_wise::PassThrough, {c_dtype},
+ GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
+ GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;
+
+ // Base pipeline for tail handling
+ using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)};
+
+ static float launch(const {host_args_type}& args, const stream_config& s) {{
+        const index_t gemm_k = {gemm_k_calc}, 1, std::multiplies<index_t>());
+
+ const index_t k_grain = args.k_batch * Config::K_Tile;
+ const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile;
+ const index_t num_loop = TilePartitioner::GetLoopNum(K_split);
+ const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
+ const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
+
+ float ave_time{{0}};
+
+ constexpr auto scheduler = Config::Scheduler;
+
+ using UniversalGemmProblem = UniversalGemmPipelineProblem<
+ {a_dtype}, {b_dtype}, AccDataType, GemmShape, GemmUniversalTraits,
+ scheduler,
+ element_wise::PassThrough, element_wise::PassThrough, {c_dtype},
+ GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
+ GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;
+
+ using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};
+
+        using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<{a_dtype}, {b_dtype}, tuple<>, AccDataType, {c_dtype},
+ typename GroupedConvTraitsType::ImplicitGemmDsLayout,
+ typename GroupedConvTraitsType::FixedGemmParams::ELayout,
+ element_wise::PassThrough,
+ TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
+ Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
+ Config::N_Warp_Tile, Config::K_Warp_Tile,
+ GroupedConvTraitsType::FixedGemmParams::TransposeC,
+ Config::NumWaveGroups,
+ GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
+ Config::VectorSizeC, false, 1, Config::DoubleSmemBuffer>>;
+
+ using Kernel = {kernel_type}<
+ GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;
+
+ const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {{
+ auto kargs = Kernel::MakeKernelArgs(args);
+
+ if (!Kernel::IsSupportedArgument(kargs)) {{
+ throw std::runtime_error("Arguments not supported for grouped conv kernel");
+ }}
+
+ const dim3 grids = Kernel::GridSize(kargs);
+ const dim3 blocks = Kernel::BlockSize();
+
+        ave_time = launch_kernel(s, make_kernel<Config::kBlockPerCu>(
+ Kernel{{}}, grids, blocks, 0, kargs));
+
+ return ave_time;
+ }};
+
+ BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
+ return ave_time;
+ }}
+}};
+
+// Launcher alias for tile_engine compatibility
+using {launcher_alias} = {kernel_name}_Launcher;
+
+}} // namespace {ns_name}
+
+// Export specific launcher to global namespace
+using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher;
+
+// When used with -include compiler flag, export aliases to global namespace
+#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
+using {launcher_alias} = {ns_name}::{launcher_alias};
+constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME;
+#endif
+"""
+
+ # Pipelines that accept GroupedConvUniversalPipelineAgBgCrPolicy
+ # as a second template parameter for conv-specific LDS layout.
+ # (from conv_configs.hpp PipelineTypeTraits -- basic_v1/mem/compv3)
+ # CompV4/V5/V6/comp_async/basic_async_v1 use their own default policies.
+ _CONV_POLICY_PIPELINES = {"basic_v1", "mem", "compv3"}
+
+ def _get_pipeline(self, pipeline: str) -> str:
+ """Get pipeline class name."""
+ pipelines = {
+ "basic_v1": "GemmPipelineAGmemBGmemCRegV1",
+ "mem": "GemmPipelineAgBgCrMem",
+ "compv3": "GemmPipelineAgBgCrCompV3",
+ "compv4": "GemmPipelineAgBgCrCompV4",
+ "compv5": "GemmPipelineAgBgCrCompV5",
+ "compv6": "GemmPipelineAgBgCrCompV6",
+ "comp_async": "GemmPipelineAgBgCrCompAsync",
+ "basic_async_v1": "GemmPipelineAGmemBGmemCRegAsyncV1",
+ }
+ return pipelines.get(pipeline, "GemmPipelineAgBgCrCompV3")
+
+ def _get_pipeline_template_args(self, pipeline: str, problem_type: str) -> str:
+ """Get full template argument list for pipeline instantiation.
+
+ For basic_v1/mem/compv3, passes GroupedConvUniversalPipelineAgBgCrPolicy
+ as a second template argument for conv-specific LDS banking.
+ """
+ base = self._get_pipeline(pipeline)
+ if pipeline in self._CONV_POLICY_PIPELINES:
+ return f"{base}<{problem_type}, GroupedConvUniversalPipelineAgBgCrPolicy>"
+ return f"{base}<{problem_type}>"
+
+ def _get_base_pipeline(self, pipeline: str) -> str:
+ """Get base pipeline class name (used for tail handling only).
+
+ Note: basic_async_v1 inherits from BaseGemmPipelineAGmemBGmemCRegV1
+ (there is no separate BaseGemmPipelineAGmemBGmemCRegAsyncV1).
+ """
+ pipelines = {
+ "basic_v1": "BaseGemmPipelineAGmemBGmemCRegV1",
+ "mem": "BaseGemmPipelineAgBgCrMem",
+ "compv3": "BaseGemmPipelineAgBgCrCompV3",
+ "compv4": "BaseGemmPipelineAgBgCrCompV4",
+ "compv5": "BaseGemmPipelineAgBgCrCompV5",
+ "compv6": "BaseGemmPipelineAgBgCrCompV6",
+ "comp_async": "BaseGemmPipelineAgBgCrCompAsync",
+ "basic_async_v1": "BaseGemmPipelineAGmemBGmemCRegV1",
+ }
+ return pipelines.get(pipeline, "BaseGemmPipelineAgBgCrCompV3")
+
+ def _kernel_instance_two_stage(
+ self, config: GroupedConvKernelConfig, kernel_name: str
+ ) -> str:
+ """Generate two-stage bwd_weight kernel: GEMM into fp32 workspace + ElementWise convert.
+
+ Mirrors grouped_convolution_backward_weight_two_stage_invoker.hpp from
+ example/ck_tile/20_grouped_convolution/.
+ """
+ tr = config.trait
+ ns_name = "ns_" + kernel_name.replace("-", "_")
+ direction_prefix = "BWD_WEIGHT"
+ launcher_alias = "SelectedConvBwdWeightLauncher"
+
+ return f"""
+namespace {ns_name} {{
+
+using Config = {kernel_name}_Config;
+constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = "{kernel_name}";
+using SelectedConv{direction_prefix.title()}Kernel = Config;
+
+struct {kernel_name}_Launcher {{
+ using KernelConfig = Config;
+ using InDataType = typename Config::InDataType;
+ using WeiDataType = typename Config::WeiDataType;
+ using OutDataType = typename Config::OutDataType;
+ using AccDataType = typename Config::AccDataType;
+ using InLayout = typename Config::InLayout;
+ using WeiLayout = typename Config::WeiLayout;
+ using OutLayout = typename Config::OutLayout;
+ using WorkspaceDataType = float;
+
+ static constexpr index_t NDimSpatial = Config::NDimSpatial;
+ // Two-stage forces VectorSizeC = 1 for workspace writes
+ static constexpr index_t VectorSizeC_TwoStage = 1;
+
+ using GemmShape = TileGemmShape<
+ sequence<Config::M_Tile, Config::N_Tile, Config::K_Tile>,
+ sequence<Config::M_Warp, Config::N_Warp, Config::K_Warp>,
+ sequence<Config::M_Warp_Tile, Config::N_Warp_Tile, Config::K_Warp_Tile>>;
+
+ static constexpr auto ConvSpec = ConvolutionSpecialization::Default;
+ using GroupedConvTraitsType = GroupedConvTraits<
+ NDimSpatial, ConvSpec, InLayout, WeiLayout, tuple<>, OutLayout,
+ Config::VectorSizeA, Config::VectorSizeB, VectorSizeC_TwoStage,
+ Config::NumGroupsToMerge, Config::EnableSplitImage, Config::ExplicitGemm>;
+
+ using TilePartitioner = GemmSpatiallyLocalTilePartitioner<
+ GemmShape,
+ GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
+ GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
+
+ using GemmUniversalTraits = TileGemmUniversalTraits<
+ GroupedConvTraitsType::FixedGemmParams::kPadM,
+ GroupedConvTraitsType::FixedGemmParams::kPadN,
+ GroupedConvTraitsType::FixedGemmParams::kPadK,
+ Config::DoubleSmemBuffer,
+ typename GroupedConvTraitsType::AsLayoutBwdWeight,
+ typename GroupedConvTraitsType::BsLayoutBwdWeight,
+ typename GroupedConvTraitsType::CLayoutBwdWeight,
+ GroupedConvTraitsType::FixedGemmParams::TransposeC,
+ GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
+ GroupedConvTraitsType::FixedGemmParams::Persistent,
+ Config::NumWaveGroups>;
+
+ using GemmPipelineProblem = ck_tile::GemmPipelineProblem<
+ OutDataType, InDataType, AccDataType, GemmShape,
+ typename GroupedConvTraitsType::template GroupedConvImplicitGemmTraitsBwdWeight,
+ element_wise::PassThrough, element_wise::PassThrough, WeiDataType,
+ GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
+ GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;
+
+ using BaseGemmPipeline = {self._get_base_pipeline(tr.pipeline)}<GemmPipelineProblem>;
+
+ static float launch(const GroupedConvBwdWeightHostArgs& args, const stream_config& s) {{
+ const index_t gemm_k = args.N_ * std::accumulate(
+ args.output_spatial_lengths_.begin(), args.output_spatial_lengths_.end(),
+ 1, std::multiplies<index_t>());
+
+ const index_t k_grain = args.k_batch * Config::K_Tile;
+ const index_t K_split = (gemm_k + k_grain - 1) / k_grain * Config::K_Tile;
+ const index_t num_loop = TilePartitioner::GetLoopNum(K_split);
+ const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
+ const TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
+
+ float ave_time{{0}};
+
+ constexpr auto scheduler = Config::Scheduler;
+
+ using UniversalGemmProblem = UniversalGemmPipelineProblem<
+ OutDataType, InDataType, AccDataType, GemmShape, GemmUniversalTraits,
+ scheduler,
+ element_wise::PassThrough, element_wise::PassThrough, WeiDataType,
+ GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
+ GroupedConvTraitsType::VectorSizeA, GroupedConvTraitsType::VectorSizeB>;
+
+ using GemmPipeline = {self._get_pipeline_template_args(tr.pipeline, "UniversalGemmProblem")};
+
+ // Epilogue writes to fp32 workspace (not fp16 output)
+ using ConvEpilogue = CShuffleEpilogue<CShuffleEpilogueProblem<AccDataType, WorkspaceDataType,
+ typename GroupedConvTraitsType::ImplicitGemmDsLayout,
+ typename GroupedConvTraitsType::FixedGemmParams::ELayout,
+ element_wise::PassThrough,
+ TilePartitioner::MPerBlock, TilePartitioner::NPerBlock,
+ Config::M_Warp, Config::N_Warp, Config::M_Warp_Tile,
+ Config::N_Warp_Tile, Config::K_Warp_Tile,
+ GroupedConvTraitsType::FixedGemmParams::TransposeC,
+ Config::NumWaveGroups,
+ GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
+ GroupedConvTraitsType::VectorSizeC>>;
+
+ using Kernel = GroupedConvolutionBackwardWeightKernel<
+ GroupedConvTraitsType, TilePartitioner, GemmPipeline, ConvEpilogue>;
+
+ // ElementWise kernel: fp32 workspace -> fp16/bf16 output
+ using XElementwiseOp = element_wise::UnaryConvert;
+ using EwBlockTile = sequence<2048>;
+ using EwBlockWarps = sequence<8>;
+ using EwWarpTile = sequence<64>;
+ using EwShape = ElementWiseShape<EwBlockTile, EwBlockWarps, EwWarpTile>;
+ using EwProblem = ElementWisePipelineProblem<
+ WorkspaceDataType, WorkspaceDataType, WeiDataType, EwShape, XElementwiseOp>;
+ using EwKernel = ElementWiseKernel<EwProblem>;
+
+ // Workspace: G * K * C * product(filter_spatial) elements in fp32
+ const index_t spatial_accum = std::accumulate(
+ args.filter_spatial_lengths_.begin(), args.filter_spatial_lengths_.end(),
+ 1, std::multiplies<index_t>());
+ DeviceMem ws_buf(args.G_ * args.K_ * args.C_ * spatial_accum * sizeof(WorkspaceDataType));
+
+ GroupedConvBwdWeightHostArgs ws_args(args);
+ auto* c_ptr = ws_args.wei_ptr;
+ ws_args.wei_ptr = ws_buf.GetDeviceBuffer();
+
+ auto kargs = Kernel::MakeKernelArgs(ws_args);
+
+ if(!Kernel::IsSupportedArgument(kargs)) {{
+ throw std::runtime_error("Arguments not supported for two-stage bwd_weight kernel");
+ }}
+
+ const dim3 grids = Kernel::GridSize(kargs);
+ const dim3 blocks = Kernel::BlockSize();
+
+ // ElementWise kernel setup
+ const index_t ew_block_size = EwKernel::BlockSize();
+ const index_t total_elems = args.G_ * args.K_ * args.C_ * spatial_accum;
+ constexpr index_t elems_per_block = EwBlockTile::at(number<0>{{}});
+ const index_t ew_grid_size = (total_elems + elems_per_block - 1) / elems_per_block;
+
+ auto ew_shape = make_tuple(args.G_ * args.K_,
+ args.C_ * spatial_accum);
+ auto ew_inputs = make_tuple(static_cast<const WorkspaceDataType*>(ws_args.wei_ptr));
+
+ if(!EwKernel::IsSupportedArgument(ew_shape)) {{
+ throw std::runtime_error("ElementWise arguments not supported for two-stage convert");
+ }}
+
+ auto preprocess = [&]() {{
+ if(kargs.k_batch > 1)
+ hip_check_error(hipMemsetAsync(
+ ws_args.wei_ptr, 0,
+ total_elems * sizeof(WorkspaceDataType),
+ s.stream_id_));
+ }};
+
+ ave_time = launch_kernel_time_mask(
+ s, preprocess,
+ make_kernel(Kernel{{}}, grids, blocks, 0, kargs),
+ make_kernel(
+ EwKernel{{}}, ew_grid_size, ew_block_size, 0,
+ ew_shape,
+ make_tuple(args.C_ * spatial_accum, 1),
+ make_tuple(args.C_ * spatial_accum, 1),
+ ew_inputs,
+ static_cast<WeiDataType*>(c_ptr)));
+
+ return ave_time;
+ }}
+}};
+
+using {launcher_alias} = {kernel_name}_Launcher;
+
+}} // namespace {ns_name}
+
+using {kernel_name}_Launcher = {ns_name}::{kernel_name}_Launcher;
+
+#ifdef CK_TILE_SINGLE_KERNEL_INCLUDE
+using {launcher_alias} = {ns_name}::{launcher_alias};
+constexpr const char* CONV_{direction_prefix}_KERNEL_NAME = {ns_name}::CONV_{direction_prefix}_KERNEL_NAME;
+#endif
+"""
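As a cross-check of the workspace and grid arithmetic embedded in the two-stage template above, here is a standalone sketch (the helper name is hypothetical; 2048 matches the `EwBlockTile` value used in the template, and fp32 workspace elements are 4 bytes):

```python
import math

def two_stage_launch_params(G, K, C, filter_spatial, elems_per_block=2048, elem_bytes=4):
    """Return (workspace bytes, elementwise grid size) for the two-stage path."""
    # fp32 workspace holds G * K * C * prod(filter_spatial) elements
    total = G * K * C * math.prod(filter_spatial)
    # convert kernel covers the workspace with ceil(total / elems_per_block) blocks
    grid = (total + elems_per_block - 1) // elems_per_block
    return total * elem_bytes, grid

ws_bytes, ew_grid = two_stage_launch_params(2, 64, 32, (3, 3))
```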
+
+
+# ============================================================================
+# Dispatcher Wrapper Generator
+# ============================================================================
+
+
+class GroupedConvDispatcherWrapperGenerator:
+ """Generates dispatcher integration wrapper following GEMM pattern"""
+
+ # Static mappings for pipeline and scheduler enum names (matches kernel_key.hpp)
+ PIPELINE_TO_DISPATCHER = {
+ "mem": "Pipeline::Mem",
+ "compv3": "Pipeline::CompV3",
+ "compv4": "Pipeline::CompV4",
+ "compv5": "Pipeline::CompV5",
+ "preshufflev1": "Pipeline::PreShuffleV1",
+ "preshufflev2": "Pipeline::PreShuffleV2",
+ }
+
+ SCHEDULER_TO_DISPATCHER = {
+ "default": "Scheduler::Default",
+ "intrawave": "Scheduler::Intrawave",
+ "interwave": "Scheduler::Interwave",
+ }
+
+ def __init__(
+ self,
+ datatype: str,
+ variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
+ ):
+ self.datatype = datatype
+ self.variant = variant
+
+ def _pipeline_to_dispatcher(self, pipeline: str) -> str:
+ """Convert pipeline string to dispatcher enum value"""
+ return self.PIPELINE_TO_DISPATCHER.get(
+ pipeline.lower(), f"Pipeline::{pipeline.capitalize()}"
+ )
+
+ def _scheduler_to_dispatcher(self, scheduler: str) -> str:
+ """Convert scheduler string to dispatcher enum value"""
+ return self.SCHEDULER_TO_DISPATCHER.get(
+ scheduler.lower(), f"Scheduler::{scheduler.capitalize()}"
+ )
+
+ def generate(
+ self,
+ config: GroupedConvKernelConfig,
+ kernel_path: Path,
+ output_dir: Path,
+ ) -> str:
+ """Generate dispatcher wrapper with factory function for registry"""
+ kernel_name = config.name(self.datatype)
+ rel_path = kernel_path.relative_to(output_dir)
+
+ # Determine launcher type based on variant
+ if self.variant == GroupedConvVariant.FORWARD:
+ launcher_alias = "SelectedConvKernelLauncher"
+ host_args_type = "GroupedConvFwdHostArgs<>"
+ conv_type_str = "forward"
+ elif self.variant == GroupedConvVariant.BACKWARD_DATA:
+ launcher_alias = "SelectedConvBwdDataLauncher"
+ host_args_type = "GroupedConvBwdDataHostArgs"
+ conv_type_str = "bwd_data"
+ else: # BACKWARD_WEIGHT
+ launcher_alias = "SelectedConvBwdWeightLauncher"
+ host_args_type = "GroupedConvBwdWeightHostArgs"
+ conv_type_str = "bwd_weight"
+
+ return f"""// SPDX-License-Identifier: MIT
+// Auto-generated dispatcher wrapper for: {kernel_name}
+#pragma once
+
+#include "ck_tile/dispatcher.hpp"
+#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
+#include "../{rel_path}"
+
+namespace ck_tile {{
+namespace dispatcher {{
+namespace generated {{
+
+using ::ck_tile::dispatcher::GroupedConvKernelInstancePtr;
+using ::ck_tile::dispatcher::GroupedConvKernelKey;
+using ::ck_tile::dispatcher::DataType;
+using ::ck_tile::dispatcher::LayoutTag;
+using ::ck_tile::dispatcher::Pipeline;
+using ::ck_tile::dispatcher::Scheduler;
+using ::ck_tile::dispatcher::Epilogue;
+using Priority = ::ck_tile::dispatcher::GroupedConvRegistry::Priority;
+
+// Factory function to create kernel instance for registry
+inline GroupedConvKernelInstancePtr make_{kernel_name}(const std::string& gfx_arch = "gfx942") {{
+ GroupedConvKernelKey key;
+ key.signature.dtype_in = DataType::{self.datatype.upper()};
+ key.signature.dtype_wei = DataType::{self.datatype.upper()};
+ key.signature.dtype_out = DataType::{self.datatype.upper()};
+ key.signature.dtype_acc = DataType::FP32;
+ key.signature.layout = "nhwgc";
+ key.signature.conv_type = "{conv_type_str}";
+ key.signature.num_dims = {config.ndim_spatial};
+ key.signature.groups = 1;
+
+ key.algorithm.tile_shape = {{{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k}}};
+ key.algorithm.wave_shape = {{{config.tile.warp_m}, {config.tile.warp_n}, 1}};
+ key.algorithm.warp_tile_shape = {{{config.tile.warp_tile_m}, {config.tile.warp_tile_n}, {config.tile.warp_tile_k}}};
+ key.algorithm.pipeline = {self._pipeline_to_dispatcher(config.trait.pipeline)};
+ key.algorithm.scheduler = {self._scheduler_to_dispatcher(config.trait.scheduler)};
+ key.algorithm.epilogue = Epilogue::CShuffle;
+ key.gfx_arch = gfx_arch;
+
+ // Create kernel instance that wraps the launcher
+ return std::make_shared<GroupedConvKernelInstance>(
+ key,
+ "{kernel_name}",
+ []({host_args_type}& args, const stream_config& cfg) -> float {{
+ return {kernel_name}_Launcher::launch(args, cfg);
+ }}
+ );
+}}
+
+}} // namespace generated
+}} // namespace dispatcher
+}} // namespace ck_tile
+
+// Export launcher alias to global namespace for direct use
+using {launcher_alias} = {kernel_name}_Launcher;
+"""
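The generated wrapper follows a factory-plus-registry pattern: each kernel exposes a `make_<kernel>()` factory, and the master registration header later calls `registry.register_kernel(make_<kernel>(gfx_arch), priority)` on the singleton. A toy Python sketch of that pattern (names are illustrative, not the C++ API):

```python
class GroupedConvRegistrySketch:
    """Toy stand-in for the C++ GroupedConvRegistry singleton."""

    def __init__(self):
        self._kernels = []

    def register_kernel(self, instance, priority=0):
        # keep (priority, instance) so the dispatcher can rank candidates
        self._kernels.append((priority, instance))

    def count(self):
        return len(self._kernels)

def make_kernel_instance(name, gfx_arch="gfx942"):
    # mirrors the generated make_<kernel>() factory functions
    return {"name": name, "gfx_arch": gfx_arch}

registry = GroupedConvRegistrySketch()
for k in ("grouped_conv_fwd_a", "grouped_conv_fwd_b"):
    registry.register_kernel(make_kernel_instance(k), priority=1)
```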
+
+
+# ============================================================================
+# Configuration Parser
+# ============================================================================
+
+
+def get_default_configs(
+ arch: str = "gfx942",
+ variants: Optional[List[GroupedConvVariant]] = None,
+ ndims: Optional[List[int]] = None,
+) -> List[GroupedConvKernelConfig]:
+ """Get default grouped convolution configurations for target architecture"""
+ configs = []
+
+ if variants is None:
+ variants = [GroupedConvVariant.FORWARD]
+ if ndims is None:
+ ndims = [2]
+
+ # Valid configurations per variant (based on CK Tile example configs)
+ # Forward and Backward Data: standard GEMM-like tiles
+ fwd_bwd_data_tiles = [
+ # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
+ (128, 128, 32, 2, 2, 32, 32, 16), # Standard 128x128
+ (256, 256, 32, 2, 2, 32, 32, 16), # Large 256x256
+ (64, 64, 32, 1, 4, 16, 16, 16), # Small 64x64
+ (128, 64, 32, 2, 2, 32, 32, 16), # Rectangular
+ (16, 64, 64, 1, 4, 16, 16, 32), # Tall and narrow
+ ]
+
+ # Backward Weight: VERY specific tile configs that work with CK Tile's bwd_weight kernel
+ # Based on ConvConfigComputeV3 from CK Tile examples (example/ck_tile/20_grouped_convolution/)
+ # Note: Backward weight has strict constraints on warp configurations due to transpose_tile2d
+ # Only specific warp configs work: (1, 4, 1) and (4, 1, 1) are known to work
+ bwd_weight_tiles = [
+ # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_tile_m, warp_tile_n, warp_tile_k)
+ # ConvConfigComputeV3: The primary working config for backward weight
+ (16, 64, 64, 1, 4, 16, 16, 32),
+ ]
+
+ for variant in variants:
+ # Select tile configs based on variant
+ if variant == GroupedConvVariant.BACKWARD_WEIGHT:
+ tile_configs = bwd_weight_tiles
+ # Backward weight ONLY supports compv3 (compv4/compv5 have transpose_tile2d issues)
+ pipelines = [("compv3", "cshuffle")]
+ # Also generate two-stage variants (fp32 workspace + elementwise convert)
+ two_stage_flags = [False, True]
+ elif variant == GroupedConvVariant.BACKWARD_DATA:
+ tile_configs = fwd_bwd_data_tiles
+ # Backward data ONLY supports compv3 (compv4 has get_length issues in bwd_data kernel)
+ pipelines = [("compv3", "cshuffle")]
+ two_stage_flags = [False]
+ else:
+ tile_configs = fwd_bwd_data_tiles
+ # Only forward grouped convolution supports both compv3 and compv4
+ pipelines = [("compv3", "cshuffle"), ("compv4", "cshuffle")]
+ two_stage_flags = [False]
+ for ndim in ndims:
+ for pipeline, epilogue in pipelines:
+ for (
+ tile_m,
+ tile_n,
+ tile_k,
+ warp_m,
+ warp_n,
+ warp_tile_m,
+ warp_tile_n,
+ warp_tile_k,
+ ) in tile_configs:
+ for two_stage in two_stage_flags:
+ adj_tile_k = tile_k * 2 if pipeline == "compv4" else tile_k
+
+ trait = GroupedConvTraitConfig(
+ pipeline=pipeline,
+ scheduler="intrawave",
+ epilogue=epilogue,
+ double_smem_buffer=(pipeline == "compv4"),
+ pad_m=True,
+ pad_n=True,
+ pad_k=True,
+ two_stage=two_stage,
+ )
+
+ if not trait.is_valid():
+ continue
+
+ config = GroupedConvKernelConfig(
+ tile=TileConfig(
+ tile_m=tile_m,
+ tile_n=tile_n,
+ tile_k=adj_tile_k,
+ warp_m=warp_m,
+ warp_n=warp_n,
+ warp_k=1,
+ warp_tile_m=warp_tile_m,
+ warp_tile_n=warp_tile_n,
+ warp_tile_k=warp_tile_k,
+ ),
+ trait=trait,
+ variant=variant,
+ ndim_spatial=ndim,
+ arch=arch,
+ )
+
+ if config.is_valid_for_arch():
+ configs.append(config)
+
+ return configs
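The compv4-specific `tile_k` doubling applied inside the loop above can be isolated as a one-liner (hypothetical helper): compv4 uses a double SMEM buffer, so its effective block K is twice that of the other pipelines.

```python
def adjusted_tile_k(pipeline: str, tile_k: int) -> int:
    # compv4 keeps two K tiles in flight (double SMEM buffer),
    # so the generated config doubles the block K dimension
    return tile_k * 2 if pipeline == "compv4" else tile_k
```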
+
+
+def get_arch_filter():
+ """Get arch filter if available"""
+ try:
+ from arch_filter import ArchFilter
+
+ return ArchFilter
+ except ImportError:
+ return None
+
+
+# ============================================================================
+# Main Generator
+# ============================================================================
+
+
+class _GenItem:
+ """Item for parallel generation with progress logging."""
+
+ def __init__(
+ self,
+ idx: int,
+ total: int,
+ config: GroupedConvKernelConfig,
+ datatype: str,
+ variant: GroupedConvVariant,
+ ):
+ self.idx = idx
+ self.total = total
+ self.config = config
+ self.datatype = datatype
+ self.variant = variant
+
+ def __str__(self) -> str:
+ return f"kernel {self.idx}/{self.total}: {self.config.name(self.datatype)}"
+
+
+class UnifiedGroupedConvCodegen:
+ """Main grouped convolution code generator"""
+
+ def __init__(
+ self,
+ output_dir: Path,
+ gpu_target: str = "gfx942",
+ datatype: str = "fp16",
+ ndim_spatial: int = 2,
+ enable_arch_filter: bool = True,
+ ):
+ self.output_dir = output_dir
+ self.output_dir.mkdir(parents=True, exist_ok=True)
+
+ # Create wrapper directory for dispatcher integration
+ self.wrapper_dir = self.output_dir / "dispatcher_wrappers"
+ self.wrapper_dir.mkdir(parents=True, exist_ok=True)
+
+ self.generated_files: List[Path] = []
+ self.generated_wrappers: List[Path] = []
+ self.gpu_target = gpu_target
+ self.datatype = datatype
+ self.ndim_spatial = ndim_spatial
+
+ # Initialize architecture filter for GPU-specific validation
+ self.arch_filter = None
+ if enable_arch_filter and HAS_ARCH_FILTER:
+ try:
+ self.arch_filter = ArchFilter(gpu_target, strict_mode=False)
+ log.info(f"Architecture filter enabled for {gpu_target}")
+ except ValueError as e:
+ log.warning(f"Could not create arch filter: {e}")
+
+ def _get_configs(self) -> List[GroupedConvKernelConfig]:
+ """Get configurations for this codegen's datatype and ndim_spatial."""
+ return get_default_configs(
+ arch=self.gpu_target,
+ variants=[
+ GroupedConvVariant.FORWARD,
+ GroupedConvVariant.BACKWARD_DATA,
+ GroupedConvVariant.BACKWARD_WEIGHT,
+ ],
+ ndims=[self.ndim_spatial],
+ )
+
+ def _get_operator_type(
+ self, variant: GroupedConvVariant
+ ) -> Optional["OperatorType"]:
+ """Map GroupedConvVariant to OperatorType for arch validation"""
+ if OperatorType is None:
+ return None
+
+ variant_to_operator = {
+ GroupedConvVariant.FORWARD: OperatorType.CONV_FWD,
+ GroupedConvVariant.BACKWARD_DATA: OperatorType.CONV_BWD_DATA,
+ GroupedConvVariant.BACKWARD_WEIGHT: OperatorType.CONV_BWD_WEIGHT,
+ }
+ return variant_to_operator.get(variant, OperatorType.CONV_FWD)
+
+ def is_config_valid(
+ self, config: GroupedConvKernelConfig, datatype: str = "fp16"
+ ) -> bool:
+ """Validate configuration against architecture constraints"""
+ if not self.arch_filter or not HAS_ARCH_FILTER:
+ return True
+
+ operator = self._get_operator_type(config.variant)
+
+ return self.arch_filter.is_kernel_valid(
+ datatype_a=datatype,
+ datatype_b=datatype,
+ datatype_c=datatype,
+ tile_m=config.tile.tile_m,
+ tile_n=config.tile.tile_n,
+ tile_k=config.tile.tile_k,
+ warp_m=config.tile.warp_m,
+ warp_n=config.tile.warp_n,
+ warp_k=1, # Grouped conv typically uses warp_k=1
+ warp_tile_m=config.tile.warp_tile_m,
+ warp_tile_n=config.tile.warp_tile_n,
+ warp_tile_k=config.tile.warp_tile_k,
+ pipeline=config.trait.pipeline,
+ epilogue=config.trait.epilogue,
+ scheduler=config.trait.scheduler,
+ operator=operator,
+ )
+
+ def generate_kernel(
+ self,
+ config: GroupedConvKernelConfig,
+ datatype: str,
+ variant: GroupedConvVariant = GroupedConvVariant.FORWARD,
+ ) -> Tuple[Path, Path]:
+ """Generate a single kernel file and dispatcher wrapper. Returns (kernel_path, wrapper_path)."""
+ kernel_gen = CKTileGroupedConvKernelGenerator(datatype, variant)
+ wrapper_gen = GroupedConvDispatcherWrapperGenerator(datatype, variant)
+
+ kernel_name = config.name(datatype)
+ filename = f"{kernel_name}.hpp"
+ filepath = self.output_dir / filename
+
+ # Generate kernel header
+ content = kernel_gen.generate(config)
+ filepath.write_text(content)
+ self.generated_files.append(filepath)
+
+ # Generate dispatcher wrapper
+ wrapper_content = wrapper_gen.generate(config, filepath, self.output_dir)
+ wrapper_path = self.wrapper_dir / f"dispatcher_wrapper_{kernel_name}.hpp"
+ wrapper_path.write_text(wrapper_content)
+ self.generated_wrappers.append(wrapper_path)
+
+ # Generate .cpp compilation unit for per-kernel parallel builds
+ cpp_filename = f"{kernel_name}.cpp"
+ cpp_filepath = self.output_dir / cpp_filename
+ cpp_content = f"""// SPDX-License-Identifier: MIT
+// Auto-generated compilation unit for: {kernel_name}
+// Enables per-kernel parallel compilation with make -j
+
+#include "{filename}"
+
+namespace ck_tile {{ namespace generated {{
+ volatile bool _{kernel_name.replace("-", "_")}_loaded = true;
+}} }}
+"""
+ cpp_filepath.write_text(cpp_content)
+
+ return filepath, wrapper_path
+
+ def _generate_single_kernel(self, item: _GenItem):
+ """Generate one kernel (used by parallel_generate). Returns (kernel_path, wrapper_path) or raises."""
+ kernel_path, wrapper_path = self.generate_kernel(
+ item.config, item.datatype, item.variant
+ )
+ log.info(
+ "Generated kernel %d/%d: %s",
+ item.idx,
+ item.total,
+ item.config.name(item.datatype),
+ )
+ return (kernel_path, wrapper_path)
+
+ def generate_all(
+ self,
+ configs: Optional[List[GroupedConvKernelConfig]] = None,
+ datatypes: Optional[List[str]] = None,
+ parallel: bool = True,
+ ) -> dict:
+ """Generate all kernel files (optionally in parallel).
+
+ Configs are filtered using architecture validation before generation.
+ Returns dict with keys: kernels, wrappers, failed.
+ """
+ if configs is None:
+ configs = self._get_configs()
+ if datatypes is None:
+ datatypes = [self.datatype]
+
+ results = {"kernels": [], "wrappers": [], "failed": []}
+
+ # Filter configs using arch validation
+ valid_tasks = []
+ rejected_count = 0
+
+ for datatype in datatypes:
+ for config in configs:
+ if self.is_config_valid(config, datatype):
+ valid_tasks.append((config, datatype, config.variant))
+ else:
+ rejected_count += 1
+ log.debug(
+ f"Rejected config for {self.gpu_target}: "
+ f"{config.tile.tile_m}x{config.tile.tile_n}x{config.tile.tile_k} "
+ f"variant={config.variant.value}"
+ )
+
+ if rejected_count > 0:
+ log.info(
+ f"Filtered {rejected_count} configs for {self.gpu_target}, "
+ f"{len(valid_tasks)} remaining"
+ )
+
+ total = len(valid_tasks)
+ items = [
+ _GenItem(i, total, config, datatype, variant)
+ for i, (config, datatype, variant) in enumerate(valid_tasks, start=1)
+ ]
+
+ def _safe_generate(item: _GenItem):
+ """Wrapper that catches exceptions for failure tracking."""
+ try:
+ k, w = self._generate_single_kernel(item)
+ return ("ok", k, w, None)
+ except Exception as e:
+ return ("fail", None, None, str(e))
+
+ raw = parallel_generate(
+ _safe_generate, items, parallel=parallel and len(items) > 1
+ )
+ for r in raw:
+ if r[0] == "ok":
+ results["kernels"].append(r[1])
+ results["wrappers"].append(r[2])
+ else:
+ results["failed"].append(r[3])
+ log.error("Failed: %s", r[3])
+
+ # Generate include_all_*.hpp headers for Python ctypes libraries
+ if results["wrappers"]:
+ self._generate_include_all_headers()
+
+ return results
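`generate_all` wraps each task in `_safe_generate` so one failing config is recorded rather than aborting the whole batch. A minimal standalone sketch of that pattern:

```python
def generate_with_failure_tracking(fn, items):
    """Run fn over items, collecting successes and failure messages separately."""
    results = {"ok": [], "failed": []}
    for item in items:
        try:
            results["ok"].append(fn(item))
        except Exception as exc:
            # record the error and keep going, like _safe_generate's ("fail", ...) tuple
            results["failed"].append(str(exc))
    return results

res = generate_with_failure_tracking(lambda x: 10 // x, [5, 0, 2])
```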
+
+ def _generate_include_all_headers(self):
+ """Generate include_all_grouped_conv_*.hpp headers and registration header"""
+ # Scan output directory for ALL kernel files (not just this run's generated_files)
+ # This handles the case where fwd and bwd kernels are generated in separate make targets
+ fwd_headers = []
+ bwd_data_headers = []
+ bwd_weight_headers = []
+ fwd_kernels = []
+ bwd_data_kernels = []
+ bwd_weight_kernels = []
+
+ for filepath in self.output_dir.glob("grouped_conv_*.hpp"):
+ name = filepath.name
+ kernel_name = name[:-4]
+ if name.startswith("grouped_conv_fwd_"):
+ fwd_headers.append(name)
+ fwd_kernels.append(kernel_name)
+ elif name.startswith(("grouped_conv_bwd_data_", "grouped_conv_bwdd_")):
+ bwd_data_headers.append(name)
+ bwd_data_kernels.append(kernel_name)
+ elif name.startswith(("grouped_conv_bwd_weight_", "grouped_conv_bwdw_")):
+ bwd_weight_headers.append(name)
+ bwd_weight_kernels.append(kernel_name)
+
+ headers_to_generate = [
+ ("include_all_grouped_conv_fwd_kernels.hpp", fwd_headers, "forward"),
+ (
+ "include_all_grouped_conv_bwd_data_kernels.hpp",
+ bwd_data_headers,
+ "backward data",
+ ),
+ (
+ "include_all_grouped_conv_bwd_weight_kernels.hpp",
+ bwd_weight_headers,
+ "backward weight",
+ ),
+ ]
+
+ for header_name, kernel_headers, variant_desc in headers_to_generate:
+ header_path = self.output_dir / header_name
+ includes = "\n".join(f'#include "{h}"' for h in sorted(kernel_headers))
+
+ # Pick the first kernel as the default Selected*Launcher
+ if kernel_headers:
+ first_kernel = sorted(kernel_headers)[0][:-4] # Remove .hpp
+ if variant_desc == "forward":
+ launcher_alias = (
+ f"using SelectedConvKernelLauncher = {first_kernel}_Launcher;"
+ )
+ elif variant_desc == "backward data":
+ launcher_alias = (
+ f"using SelectedConvBwdDataLauncher = {first_kernel}_Launcher;"
+ )
+ else: # backward weight
+ launcher_alias = f"using SelectedConvBwdWeightLauncher = {first_kernel}_Launcher;"
+ else:
+ launcher_alias = "// No kernels generated for this variant"
+
+ content = f"""// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+// Auto-generated header for grouped conv {variant_desc} kernels
+#pragma once
+
+{includes}
+
+// Default launcher alias (uses first kernel)
+{launcher_alias}
+"""
+ header_path.write_text(content)
+ if kernel_headers:
+ log.info(f"Generated: {header_name} ({len(kernel_headers)} kernels)")
+
+ # Generate registration header (following GEMM pattern)
+ self._generate_registration_header(
+ fwd_kernels, bwd_data_kernels, bwd_weight_kernels
+ )
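The prefix bucketing used when scanning the output directory can be sketched as a standalone function (hypothetical name); headers that are not per-kernel files, such as the `include_all_*` headers themselves, fall through to "other" and are skipped:

```python
def classify_kernel_header(name: str) -> str:
    # bucket generated headers by filename prefix, matching the glob scan above
    if name.startswith("grouped_conv_fwd_"):
        return "fwd"
    if name.startswith(("grouped_conv_bwd_data_", "grouped_conv_bwdd_")):
        return "bwd_data"
    if name.startswith(("grouped_conv_bwd_weight_", "grouped_conv_bwdw_")):
        return "bwd_weight"
    return "other"
```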
+
+ def _generate_registration_header(
+ self,
+ fwd_kernels: List[str],
+ bwd_data_kernels: List[str],
+ bwd_weight_kernels: List[str],
+ ):
+ """Generate master registration header for all grouped conv kernels"""
+ # Scan wrapper directory for ALL wrapper files
+ all_wrappers = []
+ for wrapper_path in self.wrapper_dir.glob(
+ "dispatcher_wrapper_grouped_conv_*.hpp"
+ ):
+ all_wrappers.append(wrapper_path.name)
+
+ wrapper_includes = "\n".join(f'#include "{w}"' for w in sorted(all_wrappers))
+
+ # Generate registration calls
+ fwd_registrations = "\n ".join(
+ f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);"
+ for k in sorted(fwd_kernels)
+ )
+ bwd_data_registrations = "\n ".join(
+ f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);"
+ for k in sorted(bwd_data_kernels)
+ )
+ bwd_weight_registrations = "\n ".join(
+ f"registry.register_kernel(generated::make_{k}(gfx_arch), priority);"
+ for k in sorted(bwd_weight_kernels)
+ )
+
+ content = f"""// SPDX-License-Identifier: MIT
+// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+// Auto-generated master registration header for grouped conv kernels
+#pragma once
+
+#include "ck_tile/dispatcher.hpp"
+#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
+
+{wrapper_includes}
+
+namespace ck_tile {{
+namespace dispatcher {{
+
+using Priority = GroupedConvRegistry::Priority;
+
+inline void register_all_grouped_conv_fwd_kernels(
+ const std::string& gfx_arch = "gfx942",
+ Priority priority = Priority::Normal)
+{{
+ auto& registry = GroupedConvRegistry::instance();
+ {fwd_registrations if fwd_registrations else "// No forward kernels"}
+}}
+
+inline void register_all_grouped_conv_bwd_data_kernels(
+ const std::string& gfx_arch = "gfx942",
+ Priority priority = Priority::Normal)
+{{
+ auto& registry = GroupedConvRegistry::instance();
+ {bwd_data_registrations if bwd_data_registrations else "// No backward data kernels"}
+}}
+
+inline void register_all_grouped_conv_bwd_weight_kernels(
+ const std::string& gfx_arch = "gfx942",
+ Priority priority = Priority::Normal)
+{{
+ auto& registry = GroupedConvRegistry::instance();
+ {bwd_weight_registrations if bwd_weight_registrations else "// No backward weight kernels"}
+}}
+
+inline void register_all_grouped_conv_kernels(
+ const std::string& gfx_arch = "gfx942",
+ Priority priority = Priority::Normal)
+{{
+ register_all_grouped_conv_fwd_kernels(gfx_arch, priority);
+ register_all_grouped_conv_bwd_data_kernels(gfx_arch, priority);
+ register_all_grouped_conv_bwd_weight_kernels(gfx_arch, priority);
+}}
+
+inline std::size_t get_grouped_conv_fwd_kernel_count() {{ return {len(fwd_kernels)}; }}
+inline std::size_t get_grouped_conv_bwd_data_kernel_count() {{ return {len(bwd_data_kernels)}; }}
+inline std::size_t get_grouped_conv_bwd_weight_kernel_count() {{ return {len(bwd_weight_kernels)}; }}
+inline std::size_t get_grouped_conv_kernel_count() {{ return {len(fwd_kernels) + len(bwd_data_kernels) + len(bwd_weight_kernels)}; }}
+
+}} // namespace dispatcher
+}} // namespace ck_tile
+"""
+ reg_path = self.wrapper_dir / "register_all_grouped_conv_kernels.hpp"
+ reg_path.write_text(content)
+ log.info(f"Generated registration header: {reg_path}")
+
+
+# ============================================================================
+# CLI
+# ============================================================================
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Unified Grouped Convolution Code Generator"
+ )
+ parser.add_argument(
+ "--output",
+ "-o",
+ type=Path,
+ default=Path("build/generated_kernels"),
+ help="Output directory",
+ )
+ parser.add_argument(
+ "--datatype",
+ "-d",
+ type=str,
+ nargs="+",
+ default=["fp16"],
+ choices=["fp16", "bf16", "fp32"],
+ help="Data types to generate",
+ )
+ parser.add_argument(
+ "--variant",
+ "-v",
+ type=str,
+ nargs="+",
+ default=["forward"],
+ choices=["forward", "bwd_data", "bwd_weight"],
+ help="Grouped convolution variants",
+ )
+ parser.add_argument(
+ "--ndim",
+ "-n",
+ type=int,
+ nargs="+",
+ default=[2],
+ choices=[1, 2, 3],
+ help="Spatial dimensions",
+ )
+ parser.add_argument(
+ "--arch",
+ "-a",
+ type=str,
+ default="gfx942",
+ choices=["gfx90a", "gfx942", "gfx950", "gfx1201"],
+ help="Target GPU architecture",
+ )
+ parser.add_argument("--verbose", action="store_true", help="Verbose output")
+ parser.add_argument(
+ "--list-configs",
+ action="store_true",
+ help="List configurations without generating",
+ )
+
+ # Individual kernel configuration (when not using predefined configs)
+ parser.add_argument("--tile-m", type=int, help="Block tile M dimension")
+ parser.add_argument("--tile-n", type=int, help="Block tile N dimension")
+ parser.add_argument("--tile-k", type=int, help="Block tile K dimension")
+ parser.add_argument("--warp-m", type=int, help="Wave distribution M")
+ parser.add_argument("--warp-n", type=int, help="Wave distribution N")
+ parser.add_argument("--warp-k", type=int, default=1, help="Wave distribution K")
+ parser.add_argument("--warp-tile-m", type=int, help="Warp tile M")
+ parser.add_argument("--warp-tile-n", type=int, help="Warp tile N")
+ parser.add_argument("--warp-tile-k", type=int, default=16, help="Warp tile K")
+ parser.add_argument(
+ "--pipeline",
+ type=str,
+ choices=["mem", "compv3", "compv4", "compv5"],
+ help="Pipeline type",
+ )
+ parser.add_argument(
+ "--scheduler",
+ type=str,
+ choices=["intrawave", "interwave"],
+ help="Scheduler type",
+ )
+ parser.add_argument(
+ "--epilogue",
+ type=str,
+ default="cshuffle",
+ choices=["cshuffle", "default"],
+ help="Epilogue type",
+ )
+ parser.add_argument("--pad-m", type=lambda s: s.lower() in ("1", "true", "yes"), default=True, help="Pad M dimension (true/false)")
+ parser.add_argument("--pad-n", type=lambda s: s.lower() in ("1", "true", "yes"), default=True, help="Pad N dimension (true/false)")
+ parser.add_argument("--pad-k", type=lambda s: s.lower() in ("1", "true", "yes"), default=True, help="Pad K dimension (true/false)")
+ parser.add_argument("--vector-a", type=int, default=4, help="Vector size A")
+ parser.add_argument("--vector-b", type=int, default=8, help="Vector size B")
+ parser.add_argument("--vector-c", type=int, default=8, help="Vector size C")
+ parser.add_argument("--block-per-cu", type=int, default=1, help="Blocks per CU")
+ parser.add_argument("--num-wave-groups", type=int, default=1, help="Wave groups")
+ parser.add_argument(
+ "--num-groups-to-merge", type=int, default=1, help="Groups to merge"
+ )
+ parser.add_argument(
+ "--double-smem-buffer",
+ type=str,
+ default=None,
+ help="Double SMEM buffer (true/false)",
+ )
+
+ args = parser.parse_args()
+
+ if args.verbose:
+ logging.getLogger().setLevel(logging.DEBUG)
+
+ # Map variant strings to enums
+ variant_map = {
+ "forward": GroupedConvVariant.FORWARD,
+ "bwd_data": GroupedConvVariant.BACKWARD_DATA,
+ "bwd_weight": GroupedConvVariant.BACKWARD_WEIGHT,
+ }
+ requested_variants = [variant_map[v] for v in args.variant]
+
+ # Check if user specified custom configuration
+ custom_config = (
+ args.tile_m is not None or args.tile_n is not None or args.pipeline is not None
+ )
+
+ if custom_config:
+ # Build custom config from CLI arguments
+ tile = TileConfig(
+ tile_m=args.tile_m or 128,
+ tile_n=args.tile_n or 128,
+ tile_k=args.tile_k or 64,
+ warp_m=args.warp_m or 2,
+ warp_n=args.warp_n or 2,
+ warp_k=args.warp_k or 1,
+ warp_tile_m=args.warp_tile_m or 32,
+ warp_tile_n=args.warp_tile_n or 32,
+ warp_tile_k=args.warp_tile_k or 16,
+ )
+ pipeline = args.pipeline or "compv4"
+ # Determine double_smem_buffer: use CLI arg if given, else default based on pipeline
+ if args.double_smem_buffer is not None:
+ dsb = args.double_smem_buffer.lower() == "true"
+ else:
+ dsb = pipeline == "compv4" # compv4 requires double buffer
+
+ trait = GroupedConvTraitConfig(
+ pipeline=pipeline,
+ scheduler=args.scheduler or "intrawave",
+ epilogue=args.epilogue or "cshuffle",
+ pad_m=args.pad_m,
+ pad_n=args.pad_n,
+ pad_k=args.pad_k,
+ double_smem_buffer=dsb,
+ num_groups_to_merge=args.num_groups_to_merge,
+ )
+ config = GroupedConvKernelConfig(
+ tile=tile,
+ trait=trait,
+ variant=requested_variants[0]
+ if requested_variants
+ else GroupedConvVariant.FORWARD,
+ ndim_spatial=args.ndim[0] if args.ndim else 2,
+ arch=args.arch,
+ vector_size_a=args.vector_a,
+ vector_size_b=args.vector_b,
+ vector_size_c=args.vector_c,
+ block_per_cu=args.block_per_cu,
+ num_wave_groups=args.num_wave_groups,
+ )
+ filtered_configs = [config]
+ else:
+ # Get predefined configurations for target arch with requested variants and ndims
+ filtered_configs = get_default_configs(
+ arch=args.arch, variants=requested_variants, ndims=args.ndim
+ )
+
+ if args.list_configs:
+ print(f"Grouped convolution configurations for {args.arch}:")
+ print(f" Datatypes: {args.datatype}")
+ print(f" Variants: {args.variant}")
+ print(f" Spatial dims: {args.ndim}")
+ print(f"\nConfigurations ({len(filtered_configs)}):")
+ for cfg in filtered_configs:
+ print(f" - {cfg.name('fp16')}")
+ print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
+ print(f" Warp: {cfg.tile.warp_m}x{cfg.tile.warp_n}x{cfg.tile.warp_k}")
+ print(
+ f" WarpTile: {cfg.tile.warp_tile_m}x{cfg.tile.warp_tile_n}x{cfg.tile.warp_tile_k}"
+ )
+ print(
+ f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}, Scheduler: {cfg.trait.scheduler}"
+ )
+ print(
+ f" Padding: M={cfg.trait.pad_m}, N={cfg.trait.pad_n}, K={cfg.trait.pad_k}"
+ )
+ return
+
+ # Generate
+ codegen = UnifiedGroupedConvCodegen(
+ output_dir=args.output,
+ gpu_target=args.arch,
+ enable_arch_filter=True,
+ )
+ results = codegen.generate_all(
+ configs=filtered_configs, datatypes=args.datatype, parallel=True
+ )
+
+ print(
+ f"\nGenerated {len(results['kernels'])} grouped convolution kernel files "
+ f"for {args.arch} in {args.output}"
+ )
+ if results["failed"]:
+ print(f" Failed: {len(results['failed'])}")
+ for err in results["failed"][:5]:
+ print(f" - {err}")
+
+
+if __name__ == "__main__":
+ main()
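The `--double-smem-buffer` handling above is a tri-state pattern worth calling out: argparse leaves the flag as `None` when omitted, an explicit string is parsed as a boolean, and otherwise the default is derived from the pipeline (compv4 requires a double buffer). A minimal sketch of that precedence, with a hypothetical helper name:

```python
def resolve_double_smem_buffer(cli_value, pipeline):
    """Mirror the precedence in main(): an explicit CLI string wins,
    otherwise the compv4 pipeline implies a double LDS buffer."""
    if cli_value is not None:
        # argparse passes the raw string through; parse it case-insensitively
        return cli_value.lower() == "true"
    return pipeline == "compv4"
```

This keeps `default=None` meaningful: only a user who actually typed the flag overrides the pipeline-derived default.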
diff --git a/dispatcher/examples/CMakeLists.txt b/dispatcher/examples/CMakeLists.txt
index bda8eb0372..ab094e90cf 100644
--- a/dispatcher/examples/CMakeLists.txt
+++ b/dispatcher/examples/CMakeLists.txt
@@ -187,7 +187,6 @@ function(add_gpu_example NAME SOURCE KERNEL_HEADER)
if(HEADER_NAME STREQUAL "register_all_kernels.hpp")
# Registration header - examples include it directly
target_compile_options(${NAME} PRIVATE
- -DGEMM_KERNEL_AVAILABLE=1
-mllvm -enable-noalias-to-md-conversion=0
-Wno-undefined-func-template
-Wno-float-equal
@@ -315,6 +314,7 @@ function(add_declarative_gpu_example NAME SOURCE)
target_include_directories(${NAME} PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/../../include
${CMAKE_CURRENT_SOURCE_DIR}/../include
+ ${CMAKE_CURRENT_SOURCE_DIR}/../..
${EXAMPLE_KERNEL_DIR}
${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers
)
@@ -322,7 +322,6 @@ function(add_declarative_gpu_example NAME SOURCE)
# Force-include the generated registration header
target_compile_options(${NAME} PRIVATE
-include ${EXAMPLE_HEADER}
- -DGEMM_KERNEL_AVAILABLE=1
-mllvm -enable-noalias-to-md-conversion=0
-Wno-undefined-func-template
-Wno-float-equal
@@ -345,6 +344,7 @@ add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_v
add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp)
add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp)
add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp)
+add_declarative_gpu_example(gemm_07_gfx950_minimal gemm/cpp/07_gfx950_minimal.cpp)
# ML Heuristic example -- requires LightGBM shared library
# Derive site-packages from active Python interpreter (respects virtualenvs)
@@ -443,19 +443,79 @@ if(hip_FOUND)
endif()
add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel)
+# =============================================================================
+# Grouped Convolution C++ Examples
+# =============================================================================
+
+add_declarative_gpu_example(grouped_conv_01_basic grouped_conv/cpp/01_basic_grouped_conv.cpp)
+add_declarative_gpu_example(grouped_conv_02_all_dirs grouped_conv/cpp/02_all_directions.cpp)
+add_declarative_gpu_example(grouped_conv_03_bench_val grouped_conv/cpp/03_benchmark_validation.cpp)
+add_declarative_gpu_example(grouped_conv_04_registry_json grouped_conv/cpp/04_registry_json.cpp)
+add_declarative_gpu_example(grouped_conv_05_bwd_data grouped_conv/cpp/05_bwd_data.cpp)
+add_declarative_gpu_example(grouped_conv_06_bwd_weight grouped_conv/cpp/06_bwd_weight.cpp)
+add_declarative_gpu_example(grouped_conv_07_benchmark grouped_conv/cpp/07_multi_tile_benchmark.cpp)
+
+# =============================================================================
+# Grouped Convolution Python Library - Multi-Kernel (fwd/bwd_data/bwd_weight x 2D/3D)
+# =============================================================================
+
+# Kernel output directory for the Python conv library
+set(CONV_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/conv_python_fallback")
+set(CONV_DISPATCH_HEADER "${CONV_FALLBACK_KERNEL_DIR}/conv_python_dispatch.hpp")
+
+# Generate ALL conv kernels (fwd/bwd_data/bwd_weight x 2D/3D x multiple tile configs)
+# then create the dispatch header with 2D/3D aliases
+add_custom_command(
+ OUTPUT ${CONV_DISPATCH_HEADER}
+ COMMAND ${CMAKE_COMMAND} -E make_directory ${CONV_FALLBACK_KERNEL_DIR}
+ COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_grouped_conv_codegen.py
+ --variant forward bwd_data bwd_weight --ndim 2 3
+ --datatype fp16 --arch ${GPU_TARGET}
+ --output ${CONV_FALLBACK_KERNEL_DIR}
+ COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/generate_conv_dispatch_header.py
+ --kernel-dir ${CONV_FALLBACK_KERNEL_DIR}
+ --output ${CONV_DISPATCH_HEADER}
+ COMMENT "Generating conv kernels (fwd/bwd_data/bwd_weight x 2D/3D) for Python library..."
+ VERBATIM
+)
+
+add_custom_target(generate_conv_fallback_kernels DEPENDS ${CONV_DISPATCH_HEADER})
+
+# Conv dynamic library for Python (all 6 kernel variants)
+add_library(dispatcher_conv_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/conv_ctypes_lib.cpp)
+target_link_libraries(dispatcher_conv_lib PRIVATE ck_tile_dispatcher)
+target_include_directories(dispatcher_conv_lib PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}/../../include
+ ${CMAKE_CURRENT_SOURCE_DIR}/../include
+ ${CONV_FALLBACK_KERNEL_DIR}
+)
+target_compile_options(dispatcher_conv_lib PRIVATE
+ -include ${CONV_DISPATCH_HEADER}
+ -DGFX_ARCH="${GPU_TARGET}"
+ -mllvm -enable-noalias-to-md-conversion=0
+ -Wno-undefined-func-template
+ -Wno-float-equal
+ --offload-compress
+)
+if(hip_FOUND)
+ target_link_libraries(dispatcher_conv_lib PRIVATE hip::device hip::host)
+endif()
+add_dependencies(dispatcher_conv_lib generate_conv_fallback_kernels)
+
message(STATUS "GEMM examples configured - kernels will be generated during 'make'")
+message(STATUS "Grouped Conv examples configured - kernels will be generated during 'make'")
# Convenience target to build all Python ctypes libraries
add_custom_target(python_libs
- DEPENDS dispatcher_gemm_lib
- COMMENT "Building Python ctypes libraries (GEMM)"
+ DEPENDS dispatcher_gemm_lib dispatcher_conv_lib
+ COMMENT "Building Python ctypes libraries (GEMM + Conv)"
)
# =============================================================================
# Per-Architecture Kernel Generation Targets
# =============================================================================
-set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030)
+set(SUPPORTED_GPU_ARCHS gfx942 gfx950 gfx90a gfx1100 gfx1030)
foreach(ARCH ${SUPPORTED_GPU_ARCHS})
# GEMM kernels for this arch
diff --git a/dispatcher/examples/README.md b/dispatcher/examples/README.md
index fdee9c3583..24bea821ba 100644
--- a/dispatcher/examples/README.md
+++ b/dispatcher/examples/README.md
@@ -1,8 +1,6 @@
# CK Tile Dispatcher Examples
-Comprehensive examples for GEMM operations with GPU execution.
-
-> **Note**: Convolution examples have been moved to `ck-2/conv_archive/` for reference.
+Comprehensive examples for GEMM and Grouped Convolution operations with GPU execution.
---
@@ -60,11 +58,11 @@ python3 examples/gemm/python/08_heuristics.py
```
examples/
-├── gemm/
-│ ├── cpp/ # 6 C++ GEMM examples
-│ └── python/ # 11 Python GEMM examples
-│
-└── README.md
+|---- gemm/
+| |---- cpp/ # 6 C++ GEMM examples
+| +---- python/ # 11 Python GEMM examples
+|
++---- README.md
```
---
@@ -201,10 +199,31 @@ rocminfo | grep "Name:"
---
-## Archived Examples
+## Grouped Convolution
-Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`:
-- `examples/conv/cpp/` - 11 C++ convolution examples
-- `examples/conv/python/` - 14 Python convolution examples
+Grouped convolution support has been re-introduced with a unified infrastructure shared with GEMM.
-See the archive for convolution functionality reference.
+### Infrastructure
+
+The grouped convolution code generation, utilities, and build scripts are available:
+
+| Component | Location |
+|-----------|----------|
+| C++ Headers | `include/ck_tile/dispatcher/grouped_conv_*.hpp` |
+| Python Codegen | `codegen/unified_grouped_conv_codegen.py` |
+| Python Utils | `python/grouped_conv_utils.py` |
+| Build Script | `scripts/compile_grouped_conv_examples.py` |
+
+### Building Grouped Conv Kernels
+
+```bash
+# Generate grouped conv kernels
+python3 codegen/unified_grouped_conv_codegen.py \
+ --output build/generated_kernels \
+ --datatype fp16 --variant forward --ndim 2
+
+# Compile a grouped conv example
+python3 scripts/compile_grouped_conv_examples.py my_grouped_conv_example.cpp
+```
+
+See the [main README](../README.md#grouped-convolution-support) for more details.
diff --git a/dispatcher/examples/gemm/cpp/02_multi_size.cpp b/dispatcher/examples/gemm/cpp/02_multi_size.cpp
index 5e620209f4..ffd2858be4 100644
--- a/dispatcher/examples/gemm/cpp/02_multi_size.cpp
+++ b/dispatcher/examples/gemm/cpp/02_multi_size.cpp
@@ -21,9 +21,9 @@
* - pipeline: "compv3" -> 1 option (compv4 requires special handling)
* - scheduler: "intrawave" -> 1 option
*
- * Raw expansion: 3 × 2 = 6 configs, but arch filter validates each:
- * - tile_m must be divisible by (warp_m × warp_tile_m)
- * - tile_n must be divisible by (warp_n × warp_tile_n)
+ * Raw expansion: 3 x 2 = 6 configs, but arch filter validates each:
+ * - tile_m must be divisible by (warp_m x warp_tile_m)
+ * - tile_n must be divisible by (warp_n x warp_tile_n)
* - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16)
* Result: 4 valid wildcard kernels + 1 explicit = 5 total
*
@@ -70,13 +70,13 @@ DECL_KERNEL_SET(multi_size_kernels,
.add(Signature().dtype("fp16").layout("rcr"),
Algorithm()
.tile(64, 64, 64)
- .wave(ANY_INT, ANY_INT, 1) // ANY_INT → (1,4,1), (2,2,1), (4,1,1)
- .warp(-1, -1, -1) // -1 same as ANY_INT → (16,16,32), (32,32,16)
- .pipeline("*") // "*" → valid pipelines
- .scheduler("*") // "*" → valid schedulers
+ .wave(ANY_INT, ANY_INT, 1) // ANY_INT -> (1,4,1), (2,2,1), (4,1,1)
+ .warp(-1, -1, -1) // -1 same as ANY_INT -> (16,16,32), (32,32,16)
+ .pipeline("*") // "*" -> valid pipelines
+ .scheduler("*") // "*" -> valid schedulers
.epilogue("cshuffle"),
"gfx942"));
-// Raw: 3×2=6, arch filter removes 2 invalid → 4 valid kernels
+// Raw: 3x2=6, arch filter removes 2 invalid -> 4 valid kernels
// =============================================================================
// MAIN
@@ -116,8 +116,8 @@ int main(int argc, char* argv[])
.pipeline("*") -> expands to valid pipelines = 1
.scheduler("*") -> expands to valid schedulers = 1
- Expanded: 3 × 2 = 6 configs, but arch filter validates each:
- - wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64
+ Expanded: 3 x 2 = 6 configs, but arch filter validates each:
+ - wave x warp must divide tile: (4,1,1)x(32,32,16) invalid for 64x64
- Result: 4 valid kernels from wildcard + 1 explicit = 5 total
)";
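The divisibility rule described in the comments above (wave x warp_tile must divide the block tile) can be sketched in a few lines of Python; the tile and wave/warp values below are taken from the example, and the filter reproduces the "3 x 2 = 6 raw, 4 valid" count:

```python
from itertools import product

# 64x64 block tile from the wildcard kernel in this example
TILE_M, TILE_N = 64, 64
WAVES = [(1, 4, 1), (2, 2, 1), (4, 1, 1)]       # ANY_INT expansion
WARP_TILES = [(16, 16, 32), (32, 32, 16)]       # -1 expansion

def is_valid(wave, warp_tile):
    # tile_m must be divisible by (warp_m * warp_tile_m), same for N
    wm, wn, _ = wave
    tm, tn, _ = warp_tile
    return TILE_M % (wm * tm) == 0 and TILE_N % (wn * tn) == 0

valid = [(w, t) for w, t in product(WAVES, WARP_TILES) if is_valid(w, t)]
```

The two rejected combos are exactly (4,1,1)+(32,32,16) and (1,4,1)+(32,32,16): 4 waves x 32 warp-tile = 128, which does not divide 64.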
diff --git a/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp b/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp
new file mode 100644
index 0000000000..7e62ad2e4f
--- /dev/null
+++ b/dispatcher/examples/gemm/cpp/07_gfx950_minimal.cpp
@@ -0,0 +1,191 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+/**
+ * Example 07: Minimal gfx950 (CDNA4 / MI350) GEMM
+ *
+ * Demonstrates the dispatcher working with gfx950-specific kernels:
+ *
+ * - fp16 GEMM with standard tile configs
+ * - fp8 GEMM with gfx950-extended warp tiles (16x16x128)
+ * - 160KB LDS: gfx950 doubles the LDS from 64KB to 160KB
+ *
+ * Build: cd dispatcher/build && cmake .. -DGPU_TARGETS=gfx950 && make gemm_07_gfx950_minimal
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+
+#include "ck_tile/dispatcher.hpp"
+#include "ck_tile/dispatcher/kernel_decl.hpp"
+#include "ck_tile/dispatcher/example_args.hpp"
+
+using namespace ck_tile::dispatcher;
+using namespace ck_tile::dispatcher::backends;
+using namespace ck_tile::dispatcher::utils;
+using Signature = decl::Signature;
+using Algorithm = decl::Algorithm;
+
+// =============================================================================
+// gfx950-targeted kernel declarations
+// =============================================================================
+
+DECL_KERNEL_SET(gfx950_gemm_kernels,
+
+ // fp16 128x128x32 -- bread-and-butter config, works on all CDNA
+ .add(Signature().dtype("fp16").layout("rcr"),
+ Algorithm()
+ .tile(128, 128, 32)
+ .wave(2, 2, 1)
+ .warp(32, 32, 16)
+ .pipeline("compv3")
+ .scheduler("intrawave")
+ .epilogue("cshuffle"),
+ "gfx950")
+
+ // fp16 128x128x64 -- deeper K tile using more LDS
+ // LDS usage: 128*64*2 + 128*64*2 = 32768 bytes (fits 64KB, gfx950 has 160KB)
+ .add(Signature().dtype("fp16").layout("rcr"),
+ Algorithm()
+ .tile(128, 128, 64)
+ .wave(2, 2, 1)
+ .warp(32, 32, 16)
+ .pipeline("compv3")
+ .scheduler("intrawave")
+ .epilogue("cshuffle"),
+ "gfx950")
+
+ // fp16 64x64x32 -- small-tile variant for small problems
+ .add(Signature().dtype("fp16").layout("rcr"),
+ Algorithm()
+ .tile(64, 64, 32)
+ .wave(2, 2, 1)
+ .warp(16, 16, 32)
+ .pipeline("compv3")
+ .scheduler("intrawave")
+ .epilogue("cshuffle"),
+ "gfx950"));
+
+// =============================================================================
+// MAIN
+// =============================================================================
+
+int main(int argc, char* argv[])
+{
+ ExampleArgs args("Example 07: gfx950 Minimal GEMM",
+ "Demonstrates gfx950 (CDNA4 / MI350) dispatcher");
+ args.add_flag("--list", "List registered kernels");
+ args.add_flag("--list-verbose", "List registered kernels with full details");
+ args.add_option("--M", "1024", "Problem M dimension");
+ args.add_option("--N", "1024", "Problem N dimension");
+ args.add_option("--K", "1024", "Problem K dimension");
+ args.add_option("--arch", "gfx950", "GPU architecture (default: gfx950)");
+
+ if(!args.parse(argc, argv))
+ return 0;
+
+ std::string gfx_arch = args.get("--arch", "gfx950");
+
+ print_header("Example 07: gfx950 (CDNA4) Minimal GEMM");
+
+ // =========================================================================
+ // Architecture info
+ // =========================================================================
+ std::cout << "\ngfx950 (CDNA4 / MI350) highlights:\n";
+ std::cout << " - 160KB LDS (up from 64KB on gfx942)\n";
+ std::cout << " - Extended FP8 warp tiles: 16x16x128, 32x32x64\n";
+ std::cout << " - Packed FP4 support (pk_fp4)\n";
+ std::cout << " - Same warp configs as gfx942: [1,4,1], [2,2,1], [4,1,1]\n\n";
+
+ // =========================================================================
+ // Register kernels
+ // =========================================================================
+ std::cout << "Registering kernels for " << gfx_arch << "...\n";
+
+ Registry registry;
+ registry.set_name("gfx950_gemm");
+ REGISTER_GENERATED_KERNELS(registry, gfx_arch);
+
+ std::cout << " Registered " << registry.size() << " kernel(s)\n";
+
+ if(args.has("--list") || args.has("--list-verbose"))
+ {
+ std::cout << "\n";
+ print_registered_kernels(registry, std::cout, args.has("--list-verbose"));
+ return 0;
+ }
+
+ if(registry.size() == 0)
+ {
+ std::cerr << "ERROR: No kernels registered for " << gfx_arch << "!\n";
+ std::cerr << " Did you build with -DGPU_TARGETS=gfx950?\n";
+ return 1;
+ }
+
+ // =========================================================================
+ // Create Dispatcher
+ // =========================================================================
+ Dispatcher dispatcher(&registry);
+
+ // =========================================================================
+ // Setup Problem
+ // =========================================================================
+ const int M = args.get_int("--M", 1024);
+ const int N = args.get_int("--N", 1024);
+ const int K = args.get_int("--K", 1024);
+
+ std::cout << "\nProblem: " << M << " x " << N << " x " << K << "\n";
+
+ Problem problem(M, N, K);
+
+ using DataType = ck_tile::fp16_t;
+ GpuBuffer<DataType> a_dev(M * K);
+ GpuBuffer<DataType> b_dev(K * N);
+ GpuBuffer<DataType> c_dev(M * N);
+
+ std::vector<DataType> a_host(M * K, DataType(1.0f));
+ std::vector<DataType> b_host(K * N, DataType(1.0f));
+ a_dev.copy_from_host(a_host.data());
+ b_dev.copy_from_host(b_host.data());
+ c_dev.zero();
+
+ // =========================================================================
+ // Select and Run
+ // =========================================================================
+ auto selected = dispatcher.select_kernel(problem);
+ if(!selected)
+ {
+ std::cerr << "ERROR: No suitable kernel found for " << M << "x" << N << "x" << K << "\n";
+ return 1;
+ }
+ std::cout << " Selected: " << selected->get_name() << "\n";
+
+ float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
+ std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
+ std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n";
+
+ // =========================================================================
+ // Verify
+ // =========================================================================
+ std::cout << "\nVerification:\n";
+ std::vector<DataType> c_host(M * N);
+ c_dev.copy_to_host(c_host.data());
+
+ const float expected = static_cast<float>(K);
+ int errors = 0;
+ for(int i = 0; i < std::min(M * N, 1024); ++i)
+ {
+ if(std::abs(static_cast<float>(c_host[i]) - expected) > 0.01f * expected + 1.0f)
+ ++errors;
+ }
+
+ bool passed = (errors == 0);
+ std::cout << " Expected value: " << expected << "\n";
+ std::cout << " Errors (first 1024 elements): " << errors << "\n";
+ std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
+
+ print_separator();
+ return passed ? 0 : 1;
+}
diff --git a/dispatcher/examples/gemm/cpp/README.md b/dispatcher/examples/gemm/cpp/README.md
index 1d81a90a0e..79d60d1198 100644
--- a/dispatcher/examples/gemm/cpp/README.md
+++ b/dispatcher/examples/gemm/cpp/README.md
@@ -29,14 +29,14 @@ cd examples
## Examples
-| Example | Description | Complexity |
-|---------|-------------|------------|
-| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ |
-| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ |
-| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ |
-| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ |
-| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ |
-| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ |
+| Example | Description |
+|---------|-------------|
+| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect |
+| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations |
+| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation |
+| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection |
+| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools |
+| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets |
## Example Details
@@ -225,5 +225,5 @@ DECL_KERNEL_SET(my_kernels,
## Related Documentation
- [Python GEMM Examples](../python/README.md)
-- [Convolution Examples](../../conv/cpp/README.md)
+- [C++ Headers (GEMM + Grouped Conv)](../../../include/ck_tile/dispatcher/README.md)
- [Main Dispatcher README](../../../README.md)
diff --git a/dispatcher/examples/gemm/python/01_basic_gemm.py b/dispatcher/examples/gemm/python/01_basic_gemm.py
index 93a78d24d1..8c23da89e2 100644
--- a/dispatcher/examples/gemm/python/01_basic_gemm.py
+++ b/dispatcher/examples/gemm/python/01_basic_gemm.py
@@ -7,41 +7,37 @@
Example 01: Basic GEMM with Multiple Kernels
Demonstrates:
-1. Declaring multiple kernel configurations
-2. Printing all registered kernels
-3. Running each kernel and validating output
+1. Building a Registry with multiple kernel configurations
+2. Parallel JIT compilation via registry.build()
+3. Running each kernel and validating output against NumPy reference
4. Comparing performance across kernels
-Complexity: ★★☆☆☆
-
Usage:
python3 01_basic_gemm.py
- python3 01_basic_gemm.py --help
python3 01_basic_gemm.py --dtype bf16
python3 01_basic_gemm.py --size 2048
+ python3 01_basic_gemm.py --num-kernels 4
+ python3 01_basic_gemm.py --workers 4
"""
import sys
+import time
import argparse
from pathlib import Path
from dataclasses import dataclass
-from typing import List
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
- setup_gemm_dispatcher,
- cleanup_gemm,
- reset_for_example,
+ Registry,
+ detect_gpu_arch,
)
@dataclass
class KernelSpec:
- """Specification for a kernel configuration"""
-
name: str
tile_m: int
tile_n: int
@@ -50,80 +46,37 @@ class KernelSpec:
scheduler: str = "intrawave"
-# Define multiple kernel configurations to test (50+ kernels)
KERNEL_SPECS = [
- # Small tiles - compv3
+ # Small tiles
KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"),
KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"),
- # Small tiles - compv4
KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"),
- KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"),
- # Medium tiles - compv3
+ # Medium tiles
KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"),
KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"),
- KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"),
- # Medium tiles - compv4
KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"),
KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"),
- KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"),
- # Rectangular tiles - compv3
+ # Rectangular tiles
KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"),
KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"),
KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"),
KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"),
- # Rectangular tiles - compv4
KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"),
- KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"),
KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"),
- KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"),
- # Large tiles - compv3
+ # Large tiles
KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"),
- KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"),
KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"),
- KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"),
KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"),
- KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"),
- # Large tiles - compv4
KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"),
- KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"),
- KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"),
- KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"),
KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"),
- KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"),
- # Interwave scheduler variants
- KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"),
+ # Interwave scheduler
KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"),
- KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"),
KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"),
- # More tile_k variations - compv3
- KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"),
- KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"),
- KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"),
- # More tile_k variations - compv4
- KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"),
- KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"),
- KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"),
- # Additional rectangular
- KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"),
- KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"),
- KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"),
- KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"),
- # Additional compv4 variants
- KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"),
- KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"),
- KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"),
- KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"),
]
-def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
- """Create a KernelConfig from a spec"""
- # Adjust warp tiles based on tile size
- if spec.tile_m <= 64:
- warp_m, warp_n = 16, 16
- else:
- warp_m, warp_n = 32, 32
-
+def spec_to_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
+ warp_m, warp_n = (16, 16) if spec.tile_m <= 64 else (32, 32)
return KernelConfig(
dtype_a=dtype,
dtype_b=dtype,
@@ -148,180 +101,118 @@ def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfi
)
-def print_kernel_table(specs: List[KernelSpec], dtype: str):
- """Print a formatted table of kernel configurations"""
- print("\n" + "=" * 70)
- print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
- print("=" * 70)
- print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}")
- print(" " + "-" * 68)
-
- for i, spec in enumerate(specs, 1):
- tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
- print(
- f" {i:<3} {spec.name:<18} {tile:<14} {spec.pipeline:<10} {spec.scheduler:<12}"
- )
-
- print(" " + "-" * 68)
- print(f" Data type: {dtype}")
-
-
def main():
- parser = argparse.ArgumentParser(
- description="Basic GEMM Example with Multiple Kernels",
- formatter_class=argparse.RawDescriptionHelpFormatter,
- epilog="""
-Examples:
- python3 01_basic_gemm.py # Default FP16 with 4 kernels
- python3 01_basic_gemm.py --dtype bf16 # BF16 mode
- python3 01_basic_gemm.py --size 2048 # Larger problem size
- python3 01_basic_gemm.py --num-kernels 2 # Test only 2 kernels
- """,
- )
+ parser = argparse.ArgumentParser(description="Basic GEMM with Multiple Kernels")
+ parser.add_argument("--dtype", default="fp16", choices=["fp16", "bf16"])
+ parser.add_argument("--arch", default=detect_gpu_arch())
+ parser.add_argument("--size", type=int, default=512, help="Problem size MxNxK")
+ parser.add_argument("--num-kernels", type=int, default=0, help="0 = all")
parser.add_argument(
- "--dtype",
- default="fp16",
- choices=["fp16", "bf16", "fp32"],
- help="Data type (default: fp16)",
- )
- parser.add_argument(
- "--arch",
- default="gfx942",
- help="Target architecture (default: gfx942)",
- )
- parser.add_argument(
- "--size",
- type=int,
- default=512,
- help="Problem size MxNxK (default: 512)",
- )
- parser.add_argument(
- "--num-kernels",
- type=int,
- default=0,
- help="Number of kernels to test (0 = all)",
+ "--workers", type=int, default=0, help="Max parallel JIT workers (0 = auto)"
)
args = parser.parse_args()
- reset_for_example()
-
print("=" * 70)
print("Example 01: Basic GEMM with Multiple Kernels")
print("=" * 70)
- # Select kernels to test
specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS
- # =========================================================================
- # Step 1: Print all kernel configurations
- # =========================================================================
- print_kernel_table(specs, args.dtype)
-
- # =========================================================================
- # Step 2: Setup and test each kernel
- # =========================================================================
- print("\n" + "=" * 70)
- print(" RUNNING KERNELS")
- print("=" * 70)
-
- np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
- M, N, K = args.size, args.size, args.size
-
- results = []
-
- print(f"\n Problem size: {M}x{N}x{K}\n")
+ # Step 1: Build registry
print(
- f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}"
+ f"\n {len(specs)} kernel configurations, dtype={args.dtype}, arch={args.arch}"
)
- print(" " + "-" * 78)
-
- for i, spec in enumerate(specs, 1):
- # Create unique test data per kernel
- np.random.seed(42 + i * 1000)
- A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
- B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
-
- # Create config and setup dispatcher
- config = create_kernel_config(spec, args.dtype, args.arch)
-
- setup = setup_gemm_dispatcher(
- config=config,
- registry_name=f"kernel_{spec.name}",
- verbose=False,
- auto_rebuild=True,
+ print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}")
+ print(" " + "-" * 64)
+ for i, s in enumerate(specs, 1):
+ print(
+ f" {i:<3} {s.name:<22} {s.tile_m}x{s.tile_n}x{s.tile_k:<6} {s.pipeline:<10} {s.scheduler:<12}"
)
+ reg = Registry(name="basic_gemm")
+ for s in specs:
+ reg.register_kernel(spec_to_config(s, args.dtype, args.arch))
+
+ # Step 2: Parallel JIT build via registry.build()
+ workers = args.workers if args.workers > 0 else None
+ print(
+ f"\n--- Parallel JIT Build ({len(specs)} kernels{f', workers={workers}' if workers else ''}) ---"
+ )
+
+ t0 = time.perf_counter()
+ setups = reg.build(verbose=False, max_workers=workers)
+ jit_build_s = time.perf_counter() - t0
+
+ built = sum(1 for s in setups if s.success)
+ print(f" Built: {built}/{len(specs)} kernels in {jit_build_s:.1f} s")
+
+ if built == 0:
+ print(" ERROR: No kernels built")
+ return 1
+
+ # Step 3: Run each kernel and validate
+ print(f"\n--- Running Kernels (problem {args.size}x{args.size}x{args.size}) ---")
+ np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
+ M = N = K = args.size
+
+ np.random.seed(42)
+ A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
+ B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
+ C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
+
+ print(
+ f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Time(ms)':>10} {'TFLOPS':>10} {'MaxErr':>10} {'Status':<6}"
+ )
+ print(" " + "-" * 80)
+
+ results = []
+ for i, (spec, setup) in enumerate(zip(specs, setups), 1):
tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
if not setup.success:
print(
- f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
+ f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}"
)
- results.append((spec.name, False, 0, 0, 0))
- cleanup_gemm()
+ results.append((spec.name, False, 0.0, 0.0, 0.0))
continue
- dispatcher = setup.dispatcher
-
- # Check if size is supported
- if not dispatcher.is_supported(M, N, K):
+ disp = setup.dispatcher
+ if not disp.is_supported(M, N, K):
print(
- f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}"
+ f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'SKIP':<6}"
)
- results.append((spec.name, False, 0, 0, 0))
- cleanup_gemm()
+ results.append((spec.name, False, 0.0, 0.0, 0.0))
continue
- # Run GEMM
- result = dispatcher.run(A, B, M, N, K)
-
- if not result.success:
+ res = disp.run(A, B, M, N, K)
+ if not res.success:
print(
- f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
+ f" {i:<3} {spec.name:<22} {tile:<14} {'---':>10} {'---':>10} {'---':>10} {'FAIL':<6}"
)
- results.append((spec.name, False, 0, 0, 0))
- cleanup_gemm()
+ results.append((spec.name, False, 0.0, 0.0, 0.0))
continue
- # Validate against NumPy reference
- C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
- max_err = np.max(np.abs(result.output - C_ref))
-
- # Check if within tolerance
- passed = max_err < 1e-2
- status = "PASS" if passed else "FAIL"
-
+ max_err = float(np.max(np.abs(res.output - C_ref)))
+ ok = max_err < 1e-2
+ tag = "PASS" if ok else "FAIL"
print(
- f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}"
+ f" {i:<3} {spec.name:<22} {tile:<14} {res.time_ms:>10.4f} {res.tflops:>10.2f} {max_err:>10.2e} {tag:<6}"
)
- results.append((spec.name, passed, result.time_ms, result.tflops, max_err))
-
- cleanup_gemm()
-
- # =========================================================================
- # Step 3: Summary
- # =========================================================================
- print("\n" + "=" * 70)
- print(" SUMMARY")
- print("=" * 70)
+ results.append((spec.name, ok, res.time_ms, res.tflops, max_err))
+ # Step 4: Summary
passed = sum(1 for r in results if r[1])
failed = len(results) - passed
+ valid = [r for r in results if r[1]]
- print(f"\n Results: {passed}/{len(results)} kernels passed")
- print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}")
-
- if results:
- valid_results = [r for r in results if r[1]]
- if valid_results:
- best = max(valid_results, key=lambda x: x[3])
- print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)")
-
- if failed == 0:
- print("\n *** ALL KERNELS PASSED ***")
- else:
- print(f"\n *** {failed} KERNELS FAILED ***")
-
+ print("\n" + "=" * 70)
+ print(f" Results: {passed}/{len(results)} passed")
+ print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}")
+ print(f" JIT time: {jit_build_s:.1f} s (parallel)")
+ if valid:
+ best = max(valid, key=lambda x: x[3])
+ print(f" Best: {best[0]} ({best[3]:.2f} TFLOPS)")
+ print(f" Status: {'PASS' if failed == 0 else 'FAIL'}")
print("=" * 70)
return 0 if failed == 0 else 1
diff --git a/dispatcher/examples/gemm/python/02_batch_gemm.py b/dispatcher/examples/gemm/python/02_batch_gemm.py
index 039aba2790..745ec1c494 100644
--- a/dispatcher/examples/gemm/python/02_batch_gemm.py
+++ b/dispatcher/examples/gemm/python/02_batch_gemm.py
@@ -6,9 +6,7 @@
"""
Example 02: Batch GEMM
-Runs multiple GEMM operations with different sizes.
-
-Complexity: ★★☆☆☆
+Runs multiple GEMM operations with different sizes using JIT compilation.
Usage:
python3 02_batch_gemm.py
@@ -25,9 +23,8 @@ import numpy as np
from ctypes_utils import (
KernelConfig,
- setup_gemm_dispatcher,
- cleanup_gemm,
- reset_for_example,
+ Registry,
+ detect_gpu_arch,
)
@@ -55,20 +52,20 @@ Examples:
help="Maximum problem size (default: 4096)",
)
parser.add_argument(
- "--arch", default="gfx942", help="Target architecture (default: gfx942)"
+ "--arch",
+ default=detect_gpu_arch(),
+ help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
- reset_for_example()
-
print("=" * 60)
print("Example 02: Batch GEMM")
print("=" * 60)
# =========================================================================
- # Step 1: Setup dispatcher
+ # Step 1: JIT build dispatcher
# =========================================================================
- print("\nStep 1: Setup Dispatcher")
+ print("\nStep 1: JIT Build Dispatcher")
config = KernelConfig(
dtype_a=args.dtype,
@@ -80,19 +77,22 @@ Examples:
gfx_arch=args.arch,
)
- setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True)
- if not setup.success:
- print(f" ERROR: {setup.error}")
+ reg = Registry(name="batch_gemm")
+ reg.register_kernel(config)
+
+ setups = reg.build(verbose=True)
+ if not setups or not setups[0].success:
+ error = setups[0].error if setups else "No kernels built"
+ print(f" ERROR: {error}")
return 1
- dispatcher = setup.dispatcher
+ dispatcher = setups[0].dispatcher
# =========================================================================
# Step 2: Run batch of different sizes
# =========================================================================
print("\nStep 2: Run Batch")
- # Generate sizes up to max_size
all_sizes = [
(256, 256, 256),
(512, 512, 512),
@@ -135,9 +135,6 @@ Examples:
avg_tflops = (total_ops / 1e12) / (total_time / 1000)
print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS")
- # Cleanup
- cleanup_gemm()
-
print("\n" + "=" * 60)
print("Batch GEMM complete!")
print("=" * 60)
diff --git a/dispatcher/examples/gemm/python/03_benchmark.py b/dispatcher/examples/gemm/python/03_benchmark.py
index bec1b7e2fb..508b3f8b35 100644
--- a/dispatcher/examples/gemm/python/03_benchmark.py
+++ b/dispatcher/examples/gemm/python/03_benchmark.py
@@ -6,9 +6,8 @@
"""
Example 03: Benchmark
-Performance benchmarking with compute-optimized kernel configuration.
-
-Complexity: ★★★☆☆
+Performance benchmarking with compute-optimized kernel configuration
+using JIT compilation.
Usage:
python3 03_benchmark.py
@@ -26,9 +25,8 @@ import numpy as np
from ctypes_utils import (
KernelConfig,
- setup_gemm_dispatcher,
- cleanup_gemm,
- reset_for_example,
+ Registry,
+ detect_gpu_arch,
)
@@ -63,20 +61,20 @@ Examples:
"--iterations", type=int, default=10, help="Benchmark iterations (default: 10)"
)
parser.add_argument(
- "--arch", default="gfx942", help="Target architecture (default: gfx942)"
+ "--arch",
+ default=detect_gpu_arch(),
+ help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
- reset_for_example()
-
print("=" * 60)
print("Example 03: Benchmark")
print("=" * 60)
# =========================================================================
- # Step 1: Setup dispatcher with compute-optimized config
+ # Step 1: JIT build dispatcher with compute-optimized config
# =========================================================================
- print("\nStep 1: Setup Dispatcher")
+ print("\nStep 1: JIT Build Dispatcher")
config = KernelConfig(
dtype_a=args.dtype,
@@ -90,12 +88,16 @@ Examples:
gfx_arch=args.arch,
)
- setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True)
- if not setup.success:
- print(f" ERROR: {setup.error}")
+ reg = Registry(name="benchmark")
+ reg.register_kernel(config)
+
+ setups = reg.build(verbose=True)
+ if not setups or not setups[0].success:
+ error = setups[0].error if setups else "No kernels built"
+ print(f" ERROR: {error}")
return 1
- dispatcher = setup.dispatcher
+ dispatcher = setups[0].dispatcher
# =========================================================================
# Step 2: Benchmark
@@ -130,11 +132,9 @@ Examples:
A = np.random.randn(M, K).astype(np_dtype) * 0.1
B = np.random.randn(K, N).astype(np_dtype) * 0.1
- # Warmup
for _ in range(args.warmup):
dispatcher.run(A, B, M, N, K)
- # Benchmark
times = []
for _ in range(args.iterations):
result = dispatcher.run(A, B, M, N, K)
@@ -150,9 +150,6 @@ Examples:
f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}"
)
- # Cleanup
- cleanup_gemm()
-
# Summary
print("\n" + "=" * 60)
print("Summary")
diff --git a/dispatcher/examples/gemm/python/04_validation.py b/dispatcher/examples/gemm/python/04_validation.py
index 2fe54c53f7..d56621c3c8 100644
--- a/dispatcher/examples/gemm/python/04_validation.py
+++ b/dispatcher/examples/gemm/python/04_validation.py
@@ -6,9 +6,7 @@
"""
Example 04: Validation
-Validates GPU GEMM against NumPy reference.
-
-Complexity: ★★★☆☆
+Validates GPU GEMM against NumPy reference using JIT compilation.
Usage:
python3 04_validation.py
@@ -26,9 +24,8 @@ import numpy as np
from ctypes_utils import (
KernelConfig,
Validator,
- setup_gemm_dispatcher,
- cleanup_gemm,
- reset_for_example,
+ Registry,
+ detect_gpu_arch,
)
@@ -56,20 +53,20 @@ Examples:
"--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)"
)
parser.add_argument(
- "--arch", default="gfx942", help="Target architecture (default: gfx942)"
+ "--arch",
+ default=detect_gpu_arch(),
+ help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
- reset_for_example()
-
print("=" * 60)
print("Example 04: Validation")
print("=" * 60)
# =========================================================================
- # Step 1: Setup dispatcher
+ # Step 1: JIT build dispatcher
# =========================================================================
- print("\nStep 1: Setup Dispatcher")
+ print("\nStep 1: JIT Build Dispatcher")
config = KernelConfig(
dtype_a=args.dtype,
@@ -81,12 +78,16 @@ Examples:
gfx_arch=args.arch,
)
- setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True)
- if not setup.success:
- print(f" ERROR: {setup.error}")
+ reg = Registry(name="validation")
+ reg.register_kernel(config)
+
+ setups = reg.build(verbose=True)
+ if not setups or not setups[0].success:
+ error = setups[0].error if setups else "No kernels built"
+ print(f" ERROR: {error}")
return 1
- dispatcher = setup.dispatcher
+ dispatcher = setups[0].dispatcher
# =========================================================================
# Step 2: Run validation tests
@@ -139,9 +140,6 @@ Examples:
print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED")
failed += 1
- # Cleanup
- cleanup_gemm()
-
# Summary
print("\n" + "=" * 60)
total = passed + failed
diff --git a/dispatcher/examples/gemm/python/05_numpy_integration.py b/dispatcher/examples/gemm/python/05_numpy_integration.py
index 493ce46d22..b0af5fa700 100644
--- a/dispatcher/examples/gemm/python/05_numpy_integration.py
+++ b/dispatcher/examples/gemm/python/05_numpy_integration.py
@@ -8,7 +8,6 @@ Example 05: NumPy Integration
Shows how to create a GPU-accelerated matmul wrapper.
-Complexity: ★★☆☆☆
Usage:
python3 05_numpy_integration.py
@@ -29,6 +28,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
+ detect_gpu_arch,
)
@@ -70,7 +70,9 @@ Examples:
help="Data type (default: fp16)",
)
parser.add_argument(
- "--arch", default="gfx942", help="Target architecture (default: gfx942)"
+ "--arch",
+ default=detect_gpu_arch(),
+ help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
diff --git a/dispatcher/examples/gemm/python/06_json_export.py b/dispatcher/examples/gemm/python/06_json_export.py
index 9e062e507b..780032ce06 100644
--- a/dispatcher/examples/gemm/python/06_json_export.py
+++ b/dispatcher/examples/gemm/python/06_json_export.py
@@ -8,7 +8,6 @@ Example 06: JSON Export
Exports registry configuration to JSON.
-Complexity: ★★☆☆☆
Usage:
python3 06_json_export.py
@@ -28,6 +27,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
+ detect_gpu_arch,
)
@@ -54,7 +54,9 @@ Examples:
help="Data type (default: fp16)",
)
parser.add_argument(
- "--arch", default="gfx942", help="Target architecture (default: gfx942)"
+ "--arch",
+ default=detect_gpu_arch(),
+ help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
diff --git a/dispatcher/examples/gemm/python/07_stress_test.py b/dispatcher/examples/gemm/python/07_stress_test.py
index 8160030631..620e66eeaf 100644
--- a/dispatcher/examples/gemm/python/07_stress_test.py
+++ b/dispatcher/examples/gemm/python/07_stress_test.py
@@ -18,7 +18,6 @@ This tests:
- Multiple data types (fp16, bf16)
- Different schedulers (intrawave, interwave)
-Complexity: ★★★★☆
Usage:
python3 07_stress_test.py
@@ -43,6 +42,7 @@ from ctypes_utils import (
cleanup_gemm,
reset_for_example,
Validator,
+ detect_gpu_arch,
)
@@ -413,8 +413,8 @@ Examples:
)
parser.add_argument(
"--arch",
- default="gfx942",
- help="Target architecture (default: gfx942)",
+ default=detect_gpu_arch(),
+ help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)",
)
args = parser.parse_args()
diff --git a/dispatcher/examples/gemm/python/08_heuristics.py b/dispatcher/examples/gemm/python/08_heuristics.py
index e2763c0513..acbf1b3ae0 100644
--- a/dispatcher/examples/gemm/python/08_heuristics.py
+++ b/dispatcher/examples/gemm/python/08_heuristics.py
@@ -19,7 +19,6 @@ Heuristic strategies:
- Memory-bound: Optimize memory access for bandwidth-limited cases
- Latency-focused: Minimize kernel launch overhead for small problems
-Complexity: ★★★★☆
Usage:
python3 08_heuristics.py
@@ -43,6 +42,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
+ detect_gpu_arch,
)
@@ -561,8 +561,8 @@ Examples:
)
parser.add_argument(
"--arch",
- default="gfx942",
- help="Target architecture (default: gfx942)",
+ default=detect_gpu_arch(),
+ help="Target architecture (auto-detected from rocminfo, override with --arch gfxNNN)",
)
args = parser.parse_args()
diff --git a/dispatcher/examples/gemm/python/09_multi_registry.py b/dispatcher/examples/gemm/python/09_multi_registry.py
index 97cbce3497..5d9af239d4 100644
--- a/dispatcher/examples/gemm/python/09_multi_registry.py
+++ b/dispatcher/examples/gemm/python/09_multi_registry.py
@@ -8,7 +8,6 @@ Example 09: Multiple Registries
Demonstrates multiple registries for different optimization targets.
-Complexity: ★★★★★
Usage:
python3 09_multi_registry.py
@@ -30,6 +29,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
+ detect_gpu_arch,
)
@@ -50,7 +50,9 @@ Examples:
help="Data type (default: fp16)",
)
parser.add_argument(
- "--arch", default="gfx942", help="Target architecture (default: gfx942)"
+ "--arch",
+ default=detect_gpu_arch(),
+ help="Target architecture (auto-detected from rocminfo)",
)
args = parser.parse_args()
diff --git a/dispatcher/examples/gemm/python/10_advanced_benchmark.py b/dispatcher/examples/gemm/python/10_advanced_benchmark.py
index e16e4e271f..b1462478d0 100644
--- a/dispatcher/examples/gemm/python/10_advanced_benchmark.py
+++ b/dispatcher/examples/gemm/python/10_advanced_benchmark.py
@@ -33,6 +33,7 @@ from ctypes_utils import (
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
+ detect_gpu_arch,
)
@@ -69,7 +70,11 @@ def parse_args():
# Kernel configuration
parser.add_argument("--dtype", default="fp16", help="Data type")
parser.add_argument("--pipeline", default="compv4", help="Pipeline type")
- parser.add_argument("--arch", default="gfx942", help="GPU architecture")
+ parser.add_argument(
+ "--arch",
+ default=detect_gpu_arch(),
+ help="GPU architecture (auto-detected from rocminfo)",
+ )
return parser.parse_args()
diff --git a/dispatcher/examples/gemm/python/11_json_import.py b/dispatcher/examples/gemm/python/11_json_import.py
index 06743af406..d19395e553 100644
--- a/dispatcher/examples/gemm/python/11_json_import.py
+++ b/dispatcher/examples/gemm/python/11_json_import.py
@@ -15,7 +15,6 @@ Key Features:
- Use arch_filter validation on loaded configs
- Export to C++ DECL_KERNEL_SET format
-Complexity: ★★★☆☆
Usage:
python3 11_json_import.py
@@ -45,6 +44,7 @@ from ctypes_utils import ( # noqa: E402
cleanup_gemm,
reset_for_example,
validate_kernel_config,
+ detect_gpu_arch,
)
# Sample JSON configuration (embedded for demonstration)
@@ -141,8 +141,8 @@ Examples:
)
parser.add_argument(
"--arch",
- default="gfx942",
- help="Target GPU architecture (default: gfx942)",
+ default=detect_gpu_arch(),
+ help="Target GPU architecture (auto-detected from rocminfo, override with --arch gfxNNN)",
)
args = parser.parse_args()
@@ -236,13 +236,13 @@ Examples:
else:
invalid_count += 1
if invalid_count <= 3: # Show first 3 invalid
- print(f"\n ✗ Invalid: {config.kernel_name()}")
+ print(f"\n FAIL Invalid: {config.kernel_name()}")
for error in result.errors:
print(f" Error: {error}")
print("\n Validation Summary:")
- print(f" ✓ Valid: {valid_count}")
- print(f" ✗ Invalid: {invalid_count}")
+ print(f" OK Valid: {valid_count}")
+ print(f" FAIL Invalid: {invalid_count}")
print(f" Total: {len(configs)}")
# =========================================================================
@@ -275,12 +275,12 @@ Examples:
disp_config, registry_name="json_import", verbose=False
)
if setup.success:
- print(" ✓ Dispatcher setup successful")
+ print(" OK Dispatcher setup successful")
print(
f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}"
)
else:
- print(f" ⚠ Dispatcher setup: {setup.error}")
+ print(f" WARNING Dispatcher setup: {setup.error}")
print(" (This is expected if kernels aren't generated)")
# =========================================================================
diff --git a/dispatcher/examples/gemm/python/README.md b/dispatcher/examples/gemm/python/README.md
index 0a83f3533f..07757b951b 100644
--- a/dispatcher/examples/gemm/python/README.md
+++ b/dispatcher/examples/gemm/python/README.md
@@ -295,5 +295,5 @@ Compilation time scales roughly linearly with kernel count.
## Related Documentation
- [C++ GEMM Examples](../cpp/README.md)
-- [Python Conv Examples](../../conv/python/README.md)
+- [Python Utilities](../../../python/README.md)
- [Main Dispatcher README](../../../README.md)
diff --git a/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp b/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp
new file mode 100644
index 0000000000..b503129c57
--- /dev/null
+++ b/dispatcher/examples/grouped_conv/cpp/01_basic_grouped_conv.cpp
@@ -0,0 +1,203 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+// Example 01: Basic Grouped Convolution
+//
+// Demonstrates three declaration patterns (mirrors GEMM 01):
+// 1. AUTOFILL - tile + pipeline only, wave/warp auto-filled
+// 2. AUTOCORRECT - invalid wave(1,1,1) corrected to valid config
+// 3. FULL - all parameters explicit (matches validated gfx942 config)
+//
+// Then runs the forward convolution on GPU and verifies output.
+//
+// Build: cd dispatcher/build && cmake .. && make grouped_conv_01_basic
+
+#include <cstddef>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/convolution_parameter.hpp"
+#include "ck_tile/ops/grouped_convolution.hpp"
+
+#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
+#include "ck_tile/dispatcher/example_args.hpp"
+
+using namespace ck_tile::dispatcher;
+using namespace ck_tile::dispatcher::grouped_conv_utils;
+using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
+using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
+
+using InDataType = ck_tile::half_t;
+using WeiDataType = ck_tile::half_t;
+using OutDataType = ck_tile::half_t;
+
+// Three declaration patterns -- codegen auto-fills/auto-corrects as needed
+DECL_GROUPED_CONV_KERNEL_SET(
+ basic_conv_kernels,
+ // Pattern 1: AUTOFILL - only tile + pipeline, rest auto-filled
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
+ GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").scheduler("intrawave"),
+ "gfx950")
+ // Pattern 2: AUTOCORRECT - wave(1,1,1) invalid, corrected to (1,4,1)
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
+ GroupedConvAlgo()
+ .tile(1, 64, 64)
+ .wave(1, 1, 1)
+ .warp(16, 16, 32)
+ .pipeline("compv3")
+ .scheduler("intrawave")
+ .epilogue("cshuffle")
+ .vector_sizes(4, 8, 8),
+ "gfx950")
+ // Pattern 3: FULL - all parameters explicit (validated config)
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
+ GroupedConvAlgo()
+ .tile(1, 128, 128)
+ .wave(2, 2, 1)
+ .warp(32, 32, 16)
+ .pipeline("compv3")
+ .scheduler("intrawave")
+ .epilogue("cshuffle")
+ .vector_sizes(4, 8, 8)
+ .block_per_cu(1),
+ "gfx950"));
+
+int main(int argc, char* argv[])
+{
+ utils::ExampleArgs args("Example 01: Basic Grouped Convolution",
+ "Declaration patterns + GPU execution");
+ args.add_option("--arch", "gfx950", "GPU architecture");
+ args.add_option("--size", "14", "Spatial size (H=W)");
+ args.add_option("-n", "1", "Batch size");
+ args.add_option("-g", "1", "Groups");
+ args.add_option("-c", "64", "Input channels C");
+ args.add_option("-k", "128", "Output channels K");
+
+ if(!args.parse(argc, argv))
+ return 0;
+
+ utils::print_header("Example 01: Basic Grouped Convolution");
+
+ std::string gfx_arch = args.get("--arch", "gfx950");
+ int N = args.get_int("-n", 1);
+ int G = args.get_int("-g", 1);
+ int C = args.get_int("-c", 64);
+ int K = args.get_int("-k", 128);
+ int HW = args.get_int("--size", 14);
+ int Y = 3, X = 3;
+
+ // Step 1: Show declared kernel sets
+ std::cout << "\nStep 1: Declared Kernel Sets\n";
+ GroupedConvKernelSetRegistry::instance().print();
+
+ // Step 2: Register kernels
+ std::cout << "\nStep 2: Register Kernels\n";
+ GroupedConvRegistry registry;
+ registry.set_name("basic_conv");
+ REGISTER_GENERATED_KERNELS(registry, gfx_arch);
+ std::cout << " Registered " << registry.size() << " kernel(s)\n";
+
+ // Step 3: Create dispatcher
+ std::cout << "\nStep 3: Create Dispatcher\n";
+    GroupedConvDispatcher dispatcher(&registry);
+
+ // Step 4: Build problem using CK Tile ConvParam
+ std::cout << "\nStep 4: Problem\n";
+ auto problem = create_grouped_conv2d_problem(N, C, K, HW, HW, Y, X, 1, 1);
+ problem.op = GroupedConvOp::Forward;
+ print_grouped_conv_problem(problem);
+
+ ck_tile::conv::ConvParam conv_param{
+ 2,
+        static_cast<ck_tile::long_index_t>(G),
+        static_cast<ck_tile::long_index_t>(N),
+        static_cast<ck_tile::long_index_t>(K),
+        static_cast<ck_tile::long_index_t>(C),
+        {static_cast<ck_tile::long_index_t>(Y), static_cast<ck_tile::long_index_t>(X)},
+        {static_cast<ck_tile::long_index_t>(HW), static_cast<ck_tile::long_index_t>(HW)},
+ {1, 1},
+ {1, 1},
+ {1, 1},
+ {1, 1}};
+
+ using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
+ using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
+ using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
+
+ auto in_desc =
+ ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param);
+ auto wei_desc =
+ ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param);
+ auto out_desc =
+ ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param);
+
+    ck_tile::HostTensor<InDataType> input_host(in_desc);
+    ck_tile::HostTensor<WeiDataType> weight_host(wei_desc);
+    ck_tile::HostTensor<OutDataType> output_host(out_desc);
+
+    ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input_host);
+    ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight_host);
+
+ ck_tile::DeviceMem input_dev(input_host.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem weight_dev(weight_host.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem output_dev(output_host.get_element_space_size_in_bytes());
+
+ input_dev.ToDevice(input_host.data());
+ weight_dev.ToDevice(weight_host.data());
+
+ // Step 5: Select and run
+ std::cout << "\nStep 5: Select and Run\n";
+
+ auto* selected = dispatcher.select_kernel(problem);
+ if(!selected)
+ {
+ std::cerr << " ERROR: No kernel found for problem!\n";
+ return 1;
+ }
+ std::cout << " Selected: " << selected->name() << "\n";
+
+ float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(),
+ weight_dev.GetDeviceBuffer(),
+ output_dev.GetDeviceBuffer(),
+ problem,
+ nullptr);
+
+ double tflops = calculate_conv_tflops(problem, time_ms);
+ std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
+ std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
+
+ // Step 6: Verify
+ std::cout << "\nStep 6: Verify\n";
+ output_dev.FromDevice(output_host.data());
+
+ size_t total = output_host.get_element_space_size();
+ size_t nonzero = 0;
+ double checksum = 0.0;
+ for(size_t i = 0; i < total; ++i)
+ {
+        float v = static_cast<float>(output_host.data()[i]);
+ if(v != 0.0f)
+ ++nonzero;
+ checksum += v;
+ }
+
+ bool passed = nonzero > 0;
+ std::cout << " Output elements: " << total << "\n";
+ std::cout << " Non-zero: " << nonzero << "/" << total
+ << (nonzero > 0 ? " (kernel produced output)" : " WARNING: all zeros!") << "\n";
+ std::cout << " Checksum: " << std::fixed << std::setprecision(2) << checksum << "\n";
+ std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
+
+ utils::print_separator();
+ std::cout << "DECLARATION PATTERNS:\n";
+ std::cout << " 1. AUTOFILL: tile + pipeline only, wave/warp auto-filled\n";
+ std::cout << " 2. AUTOCORRECT: invalid wave(1,1,1) corrected\n";
+ std::cout << " 3. FULL: all parameters explicit\n";
+ utils::print_separator();
+
+ return passed ? 0 : 1;
+}
diff --git a/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp b/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp
new file mode 100644
index 0000000000..a2f2b9d560
--- /dev/null
+++ b/dispatcher/examples/grouped_conv/cpp/02_all_directions.cpp
@@ -0,0 +1,216 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+// Example 02: All Convolution Directions
+//
+// Forward, backward-data, and backward-weight for 2D convolution,
+// each executed on GPU with non-zero verification.
+//
+// Build: cd dispatcher/build && cmake .. && make grouped_conv_02_all_dirs
+
+#include <cstddef>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/convolution_parameter.hpp"
+#include "ck_tile/ops/grouped_convolution.hpp"
+
+#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
+#include "ck_tile/dispatcher/example_args.hpp"
+
+using namespace ck_tile::dispatcher;
+using namespace ck_tile::dispatcher::grouped_conv_utils;
+using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
+using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
+
+using InDataType = ck_tile::half_t;
+using WeiDataType = ck_tile::half_t;
+using OutDataType = ck_tile::half_t;
+
+DECL_GROUPED_CONV_KERNEL_SET(
+ conv_fwd_2d,
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
+ GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8),
+ "gfx950"));
+
+DECL_GROUPED_CONV_KERNEL_SET(
+ conv_bwdd_2d,
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2),
+ GroupedConvAlgo().tile(1, 128, 128).pipeline("compv3").vector_sizes(4, 8, 8),
+ "gfx950"));
+
+DECL_GROUPED_CONV_KERNEL_SET(
+ conv_bwdw_2d,
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2),
+ GroupedConvAlgo()
+ .tile(1, 128, 128)
+ .pipeline("compv3")
+ .memory_op("atomic_add")
+ .vector_sizes(4, 8, 8),
+ "gfx950"));
+
+int main(int argc, char* argv[])
+{
+ utils::ExampleArgs args("Example 02: All Convolution Directions",
+ "Forward/BwdData/BwdWeight with GPU execution and verification");
+ args.add_option("--arch", "gfx950", "GPU architecture");
+
+ if(!args.parse(argc, argv))
+ return 0;
+
+ utils::print_header("Example 02: All Convolution Directions");
+
+ std::string gfx_arch = args.get("--arch", "gfx950");
+
+ GroupedConvRegistry registry;
+ registry.set_name("all_directions");
+ REGISTER_GENERATED_KERNELS(registry, gfx_arch);
+ std::cout << " Registered " << registry.size() << " kernel(s)\n";
+
+    GroupedConvDispatcher dispatcher(&registry);
+
+ const int N = 1, G = 1, C = 64, K = 128, Hi = 14, Wi = 14, Y = 3, X = 3;
+
+ ck_tile::conv::ConvParam conv_param{
+ 2,
+        static_cast<ck_tile::long_index_t>(G),
+        static_cast<ck_tile::long_index_t>(N),
+        static_cast<ck_tile::long_index_t>(K),
+        static_cast<ck_tile::long_index_t>(C),
+        {static_cast<ck_tile::long_index_t>(Y), static_cast<ck_tile::long_index_t>(X)},
+        {static_cast<ck_tile::long_index_t>(Hi), static_cast<ck_tile::long_index_t>(Wi)},
+ {1, 1},
+ {1, 1},
+ {1, 1},
+ {1, 1}};
+
+ using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
+ using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
+ using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
+
+ auto in_desc =
+ ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param);
+ auto wei_desc =
+ ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param);
+ auto out_desc =
+ ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param);
+
+    ck_tile::HostTensor<InDataType> input(in_desc);
+    ck_tile::HostTensor<WeiDataType> weight(wei_desc);
+    ck_tile::HostTensor<OutDataType> output(out_desc);
+
+    ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
+    ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
+
+ ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem output_dev(output.get_element_space_size_in_bytes());
+
+ input_dev.ToDevice(input.data());
+ weight_dev.ToDevice(weight.data());
+
+ std::cout << "\n " << std::left << std::setw(12) << "Direction" << std::right << std::setw(10)
+ << "Time(ms)" << std::setw(10) << "TFLOPS" << std::setw(14) << "NonZero"
+ << std::setw(10) << "Status" << "\n";
+ std::cout << " " << std::string(56, '-') << "\n";
+
+ bool all_pass = true;
+
+ auto print_result =
+ [](const char* label, float time_ms, double tflops, size_t nz, size_t total, bool ok) {
+ std::cout << " " << std::left << std::setw(12) << label << std::right << std::fixed
+ << std::setprecision(4) << std::setw(10) << time_ms << std::setprecision(2)
+ << std::setw(10) << tflops << std::setw(14)
+ << (std::to_string(nz) + "/" + std::to_string(total)) << std::setw(10)
+ << (ok ? "OK" : "FAIL") << "\n";
+ };
+
+ // Forward: run(X, W, Y)
+ {
+ auto problem =
+ create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::Forward);
+ float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(),
+ weight_dev.GetDeviceBuffer(),
+ output_dev.GetDeviceBuffer(),
+ problem,
+ nullptr);
+ output_dev.FromDevice(output.data());
+ size_t nz = 0;
+ for(size_t i = 0; i < output.get_element_space_size(); ++i)
+            if(static_cast<float>(output.data()[i]) != 0.0f)
+ ++nz;
+ bool ok = nz > 0;
+ print_result("forward",
+ time_ms,
+ calculate_conv_tflops(problem, time_ms),
+ nz,
+ output.get_element_space_size(),
+ ok);
+ if(!ok)
+ all_pass = false;
+ }
+
+ // Backward Data: run(dY, W, dX)
+ {
+ auto problem =
+ create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData);
+        ck_tile::HostTensor<InDataType> dx_host(in_desc);
+ ck_tile::DeviceMem dx_dev(dx_host.get_element_space_size_in_bytes());
+ float time_ms = dispatcher.run(output_dev.GetDeviceBuffer(), // dY (from forward pass)
+ weight_dev.GetDeviceBuffer(), // W
+ dx_dev.GetDeviceBuffer(), // dX (output)
+ problem,
+ nullptr);
+ dx_dev.FromDevice(dx_host.data());
+ size_t nz = 0;
+ for(size_t i = 0; i < dx_host.get_element_space_size(); ++i)
+            if(static_cast<float>(dx_host.data()[i]) != 0.0f)
+ ++nz;
+ bool ok = nz > 0;
+ print_result("bwd_data",
+ time_ms,
+ calculate_conv_tflops(problem, time_ms),
+ nz,
+ dx_host.get_element_space_size(),
+ ok);
+ if(!ok)
+ all_pass = false;
+ }
+
+ // Backward Weight: run(X, dY, dW)
+ {
+ auto problem = create_grouped_conv2d_problem(
+ N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight);
+        ck_tile::HostTensor<WeiDataType> dw_host(wei_desc);
+ ck_tile::DeviceMem dw_dev(dw_host.get_element_space_size_in_bytes());
+ float time_ms = dispatcher.run(input_dev.GetDeviceBuffer(), // X
+ output_dev.GetDeviceBuffer(), // dY
+ dw_dev.GetDeviceBuffer(), // dW (output)
+ problem,
+ nullptr);
+ dw_dev.FromDevice(dw_host.data());
+ size_t nz = 0;
+ for(size_t i = 0; i < dw_host.get_element_space_size(); ++i)
+            if(static_cast<float>(dw_host.data()[i]) != 0.0f)
+ ++nz;
+ bool ok = nz > 0;
+ print_result("bwd_weight",
+ time_ms,
+ calculate_conv_tflops(problem, time_ms),
+ nz,
+ dw_host.get_element_space_size(),
+ ok);
+ if(!ok)
+ all_pass = false;
+ }
+
+ utils::print_separator();
+ std::cout << " Status: " << (all_pass ? "PASS" : "FAIL") << "\n";
+ utils::print_separator();
+
+ return all_pass ? 0 : 1;
+}
diff --git a/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp b/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp
new file mode 100644
index 0000000000..12bd87d1a4
--- /dev/null
+++ b/dispatcher/examples/grouped_conv/cpp/03_benchmark_validation.cpp
@@ -0,0 +1,263 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+// Example 03: Benchmark and CPU-Reference Validation
+//
+// Runs a 2D grouped conv forward kernel on the GPU via dispatcher.run()
+// and compares against the CK Tile host reference implementation.
+// Exposes warmup/repeat/log_level as CLI args (matches example 20 pattern).
+//
+// Build: cd dispatcher/build && cmake .. && make grouped_conv_03_bench_val
+
+#include <algorithm>
+#include <cmath>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/convolution_parameter.hpp"
+#include "ck_tile/ops/grouped_convolution.hpp"
+#include "ck_tile/host/reference/reference_grouped_conv_fwd.hpp"
+
+#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
+#include "ck_tile/dispatcher/example_args.hpp"
+
+using namespace ck_tile::dispatcher;
+using namespace ck_tile::dispatcher::grouped_conv_utils;
+using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
+using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
+
+using InDataType = ck_tile::half_t;
+using WeiDataType = ck_tile::half_t;
+using OutDataType = ck_tile::half_t;
+using AccDataType = float;
+
+DECL_GROUPED_CONV_KERNEL_SET(
+ bench_kernels,
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
+ GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8),
+ "gfx950")
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
+ GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8),
+ "gfx950"));
+
+int main(int argc, char* argv[])
+{
+ utils::ExampleArgs args("Example 03: Benchmark & Validation",
+ "GPU execution with CPU reference validation");
+ args.add_option("-n", "1", "Batch size N");
+ args.add_option("-g", "1", "Groups G");
+ args.add_option("-c", "64", "Input channels C");
+ args.add_option("-k", "128", "Output channels K");
+ args.add_option("--size", "14", "Spatial size (H=W)");
+ args.add_option("--warmup", "3", "Warmup iterations");
+ args.add_option("--repeat", "10", "Benchmark iterations");
+ args.add_option("--arch", "gfx950", "GPU architecture");
+ args.add_flag("--no-verify", "Skip CPU validation");
+
+ if(!args.parse(argc, argv))
+ return 0;
+
+ utils::print_header("Example 03: Grouped Conv Benchmark & Validation");
+
+ int N = args.get_int("-n", 1);
+ int G = args.get_int("-g", 1);
+ int C = args.get_int("-c", 64);
+ int K = args.get_int("-k", 128);
+ int Hi = args.get_int("--size", 14);
+ int Wi = Hi;
+ int Y = 3, X = 3;
+ int warmup = args.get_int("--warmup", 3);
+ int repeat = args.get_int("--repeat", 10);
+ bool verify = !args.has("--no-verify");
+ std::string gfx_arch = args.get("--arch", "gfx950");
+
+ std::cout << "\nProblem: N=" << N << " G=" << G << " C=" << C << " K=" << K << " Hi=" << Hi
+ << " Wi=" << Wi << " Y=" << Y << " X=" << X << "\n";
+ std::cout << "Benchmark: warmup=" << warmup << " repeat=" << repeat << "\n";
+
+ // Step 1: Setup tensors using CK Tile descriptors
+ std::cout << "\nStep 1: Setup tensors\n";
+
+ ck_tile::conv::ConvParam conv_param{
+ 2,
+        static_cast<ck_tile::long_index_t>(G),
+        static_cast<ck_tile::long_index_t>(N),
+        static_cast<ck_tile::long_index_t>(K),
+        static_cast<ck_tile::long_index_t>(C),
+        {static_cast<ck_tile::long_index_t>(Y), static_cast<ck_tile::long_index_t>(X)},
+        {static_cast<ck_tile::long_index_t>(Hi), static_cast<ck_tile::long_index_t>(Wi)},
+ {1, 1},
+ {1, 1},
+ {1, 1},
+ {1, 1}};
+
+ using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
+ using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
+ using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
+
+ auto in_desc =
+ ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param);
+ auto wei_desc =
+ ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param);
+ auto out_desc =
+ ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param);
+
+    ck_tile::HostTensor<InDataType> input(in_desc);
+    ck_tile::HostTensor<WeiDataType> weight(wei_desc);
+    ck_tile::HostTensor<OutDataType> output_gpu(out_desc);
+    ck_tile::HostTensor<OutDataType> output_cpu(out_desc);
+
+    ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
+    ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
+ output_cpu.SetZero();
+
+ std::cout << " Input: " << input.get_element_space_size() << " elements\n";
+ std::cout << " Weight: " << weight.get_element_space_size() << " elements\n";
+ std::cout << " Output: " << output_gpu.get_element_space_size() << " elements\n";
+
+ // Step 2: CPU reference
+ if(verify)
+ {
+ std::cout << "\nStep 2: CPU Reference\n";
+
+        std::vector<ck_tile::long_index_t> strides_v = {1, 1};
+        std::vector<ck_tile::long_index_t> dilations_v = {1, 1};
+        std::vector<ck_tile::long_index_t> left_pads_v = {1, 1};
+        std::vector<ck_tile::long_index_t> right_pads_v = {1, 1};
+
+ ck_tile::reference_grouped_conv_fwd<2, InDataType, WeiDataType, OutDataType>(
+ input, weight, output_cpu, strides_v, dilations_v, left_pads_v, right_pads_v);
+
+ std::cout << " CPU ref[0..7]: ";
+        for(int i = 0; i < std::min(8, static_cast<int>(output_cpu.get_element_space_size())); ++i)
+            std::cout << std::fixed << std::setprecision(4)
+                      << static_cast<float>(output_cpu.data()[i]) << " ";
+ std::cout << "\n";
+ }
+
+ // Step 3: GPU execution via dispatcher
+ std::cout << "\nStep 3: GPU Execution (via dispatcher.run)\n";
+
+ GroupedConvRegistry registry;
+ registry.set_name("bench_val");
+ REGISTER_GENERATED_KERNELS(registry, gfx_arch);
+ std::cout << " Registered " << registry.size() << " kernel(s)\n";
+
+    GroupedConvDispatcher dispatcher(&registry);
+
+ auto problem = create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1);
+ problem.op = GroupedConvOp::Forward;
+
+ auto* selected = dispatcher.select_kernel(problem);
+ if(!selected)
+ {
+ std::cerr << " ERROR: No kernel found!\n";
+ return 1;
+ }
+ std::cout << " Selected: " << selected->name() << "\n";
+
+ ck_tile::DeviceMem input_dev(input.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem weight_dev(weight.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem output_dev(output_gpu.get_element_space_size_in_bytes());
+
+ input_dev.ToDevice(input.data());
+ weight_dev.ToDevice(weight.data());
+
+ float elapsed_ms = dispatcher.run(input_dev.GetDeviceBuffer(),
+ weight_dev.GetDeviceBuffer(),
+ output_dev.GetDeviceBuffer(),
+ problem,
+ nullptr);
+
+ output_dev.FromDevice(output_gpu.data());
+
+ size_t total = output_gpu.get_element_space_size();
+ std::cout << " GPU out[0..7]: ";
+    for(int i = 0; i < std::min(8, static_cast<int>(total)); ++i)
+        std::cout << std::fixed << std::setprecision(4) << static_cast<float>(output_gpu.data()[i])
+ << " ";
+ std::cout << "\n";
+
+ size_t nonzero_gpu = 0;
+ double gpu_sum = 0.0;
+ for(size_t i = 0; i < total; ++i)
+ {
+        float v = static_cast<float>(output_gpu.data()[i]);
+ if(v != 0.0f)
+ ++nonzero_gpu;
+ gpu_sum += v;
+ }
+ std::cout << " GPU checksum: " << std::fixed << std::setprecision(6) << gpu_sum << "\n";
+ std::cout << " GPU non-zero: " << nonzero_gpu << "/" << total
+ << (nonzero_gpu > 0 ? " (kernel produced output)" : " WARNING: all zeros!") << "\n";
+
+    int Ho = static_cast<int>(problem.Ho());
+    int Wo = static_cast<int>(problem.Wo());
+ double flops = 2.0 * G * N * K * C * Y * X * Ho * Wo;
+ double tflops = flops / (elapsed_ms * 1e9);
+
+ std::cout << " Time: " << std::fixed << std::setprecision(4) << elapsed_ms << " ms\n";
+ std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
+
+ // Step 4: Validation
+ bool passed = true;
+ if(verify)
+ {
+ std::cout << "\nStep 4: Validation (GPU vs CPU)\n";
+
+ constexpr float rtol = 1e-2f;
+ constexpr float atol = 1e-2f;
+
+ float max_diff = 0.0f;
+ float max_rel = 0.0f;
+ size_t max_diff_idx = 0;
+ size_t num_elements = output_gpu.get_element_space_size();
+ size_t mismatches = 0;
+
+ for(size_t i = 0; i < num_elements; ++i)
+ {
+            float gpu_val = static_cast<float>(output_gpu.data()[i]);
+            float cpu_val = static_cast<float>(output_cpu.data()[i]);
+ float diff = std::abs(gpu_val - cpu_val);
+ float tol = atol + rtol * std::abs(cpu_val);
+ float rel = diff / (std::abs(cpu_val) + 1e-6f);
+ if(diff > max_diff)
+ {
+ max_diff = diff;
+ max_diff_idx = i;
+ }
+ max_rel = std::max(max_rel, rel);
+ if(diff > tol)
+ ++mismatches;
+ }
+
+ passed = (mismatches == 0);
+
+ std::cout << " Side-by-side at worst element [" << max_diff_idx << "]:\n";
+ std::cout << " GPU: " << std::fixed << std::setprecision(6)
+                  << static_cast<float>(output_gpu.data()[max_diff_idx])
+                  << " CPU: " << static_cast<float>(output_cpu.data()[max_diff_idx])
+ << " diff: " << std::scientific << max_diff << "\n";
+ std::cout << " Elements: " << num_elements << "\n";
+ std::cout << " Mismatches: " << mismatches << "/" << num_elements << "\n";
+ std::cout << " Max abs diff: " << std::scientific << max_diff << "\n";
+ std::cout << " Max rel diff: " << std::scientific << max_rel << "\n";
+ std::cout << " Status: " << (passed ? "PASSED" : "FAILED") << "\n";
+ }
+
+ utils::print_separator();
+ std::cout << "BENCHMARK & VALIDATION:\n";
+ std::cout << " GPU kernel: " << (selected ? selected->name() : "none") << "\n";
+ std::cout << " Performance: " << std::fixed << std::setprecision(2) << tflops
+ << " TFLOPS\n";
+ std::cout << " CPU reference: reference_grouped_conv_fwd<2>()\n";
+ std::cout << " Validation: " << (passed ? "PASS" : "FAIL") << "\n";
+ utils::print_separator();
+
+ return passed ? 0 : 1;
+}
diff --git a/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp b/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp
new file mode 100644
index 0000000000..0e5a6d33be
--- /dev/null
+++ b/dispatcher/examples/grouped_conv/cpp/04_registry_json.cpp
@@ -0,0 +1,154 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+// Example 04: Heuristic Selection + JSON Export
+//
+// Demonstrates runtime kernel selection with heuristic ranking,
+// GPU execution, and JSON registry export.
+//
+// Build: cd dispatcher/build && cmake .. && make grouped_conv_04_registry_json
+
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/convolution_parameter.hpp"
+#include "ck_tile/ops/grouped_convolution.hpp"
+
+#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
+#include "ck_tile/dispatcher/example_args.hpp"
+
+using namespace ck_tile::dispatcher;
+using namespace ck_tile::dispatcher::grouped_conv_utils;
+using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
+using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
+
+using InDataType = ck_tile::half_t;
+using WeiDataType = ck_tile::half_t;
+using OutDataType = ck_tile::half_t;
+
+// Two tile configs for heuristic selection
+DECL_GROUPED_CONV_KERNEL_SET(
+ heuristic_kernels,
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
+ GroupedConvAlgo().tile(1, 128, 128).pipeline("compv4").vector_sizes(4, 8, 8),
+ "gfx950")
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
+ GroupedConvAlgo().tile(1, 64, 64).pipeline("compv3").vector_sizes(4, 8, 8),
+ "gfx950"));
+
+std::vector<std::string> conv_heuristic(const GroupedConvProblem& problem)
+{
+ int64_t spatial = problem.Ho() * problem.Wo();
+ if(spatial > 400)
+ return {"128x128", "64x64"};
+ return {"64x64", "128x128"};
+}
+
+int main(int argc, char* argv[])
+{
+ utils::ExampleArgs args("Example 04: Heuristic + JSON",
+ "Runtime kernel selection and JSON export");
+ args.add_option("--arch", "gfx950", "GPU architecture");
+
+ if(!args.parse(argc, argv))
+ return 0;
+
+ utils::print_header("Example 04: Heuristic Selection + JSON Export");
+
+ std::string gfx_arch = args.get("--arch", "gfx950");
+
+ // Step 1: Register
+ std::cout << "\nStep 1: Register Kernels" << std::endl;
+ GroupedConvRegistry registry;
+ registry.set_name("heuristic_conv");
+ REGISTER_GENERATED_KERNELS(registry, gfx_arch);
+ std::cout << " Registered " << registry.size() << " kernel(s)" << std::endl;
+
+ // Step 2: Heuristic dispatcher
+ std::cout << "\nStep 2: Heuristic Dispatcher" << std::endl;
+    GroupedConvDispatcher dispatcher(&registry);
+ dispatcher.set_strategy(GroupedConvDispatcher::SelectionStrategy::Heuristic);
+ dispatcher.set_heuristic(conv_heuristic);
+
+ // Step 3: Select kernels (no GPU yet)
+ std::cout << "\nStep 3: Kernel Selection" << std::endl;
+
+ auto problem = create_grouped_conv2d_problem(1, 64, 128, 14, 14, 3, 3, 1, 1);
+
+ auto* selected = dispatcher.select_kernel(problem);
+ std::cout << " Selected: " << (selected ? selected->name() : "none") << std::endl;
+
+ // Step 4: GPU execution
+ std::cout << "\nStep 4: GPU Execution" << std::endl;
+
+ ck_tile::conv::ConvParam cp{
+ 2,
+        static_cast<ck_tile::long_index_t>(1),
+        static_cast<ck_tile::long_index_t>(1),
+        static_cast<ck_tile::long_index_t>(128),
+        static_cast<ck_tile::long_index_t>(64),
+        {static_cast<ck_tile::long_index_t>(3), static_cast<ck_tile::long_index_t>(3)},
+        {static_cast<ck_tile::long_index_t>(14), static_cast<ck_tile::long_index_t>(14)},
+ {1, 1},
+ {1, 1},
+ {1, 1},
+ {1, 1}};
+
+ using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
+ using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
+ using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
+
+ std::cout << " Creating tensors..." << std::endl;
+ auto in_d = ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(cp);
+ auto wei_d = ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(cp);
+ auto out_d = ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(cp);
+
+    ck_tile::HostTensor<InDataType> input(in_d);
+    ck_tile::HostTensor<WeiDataType> weight(wei_d);
+    ck_tile::HostTensor<OutDataType> output(out_d);
+    ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
+    ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
+
+ std::cout << " Allocating device memory..." << std::endl;
+ ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem out_dev(output.get_element_space_size_in_bytes());
+ in_dev.ToDevice(input.data());
+ wei_dev.ToDevice(weight.data());
+
+ std::cout << " Launching kernel..." << std::endl;
+ float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(),
+ wei_dev.GetDeviceBuffer(),
+ out_dev.GetDeviceBuffer(),
+ problem,
+ nullptr);
+
+ std::cout << " Reading back..." << std::endl;
+ out_dev.FromDevice(output.data());
+ size_t nz = 0;
+ for(size_t i = 0; i < output.get_element_space_size(); ++i)
+        if(static_cast<float>(output.data()[i]) != 0.0f)
+ ++nz;
+
+ std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms"
+ << std::endl;
+ std::cout << " TFLOPS: " << std::setprecision(2) << calculate_conv_tflops(problem, time_ms)
+ << std::endl;
+ std::cout << " NonZero: " << nz << "/" << output.get_element_space_size() << std::endl;
+
+ // Step 5: JSON export
+ std::cout << "\nStep 5: JSON Export" << std::endl;
+ std::string json = registry.export_json(false);
+ std::cout << " JSON size: " << json.size() << " bytes" << std::endl;
+
+ bool passed = nz > 0;
+ utils::print_separator();
+ std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
+ utils::print_separator();
+
+ return passed ? 0 : 1;
+}
diff --git a/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp b/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp
new file mode 100644
index 0000000000..35595bb14c
--- /dev/null
+++ b/dispatcher/examples/grouped_conv/cpp/05_bwd_data.cpp
@@ -0,0 +1,183 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+// Example 05: Backward Data with CPU Reference Validation
+//
+// Computes dX = ConvBwdData(dY, W) on GPU via dispatcher.run()
+// and validates against ck_tile::reference_grouped_conv_bwd_data.
+//
+// Build: cd dispatcher/build && cmake .. && make grouped_conv_05_bwd_data
+
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/convolution_parameter.hpp"
+#include "ck_tile/ops/grouped_convolution.hpp"
+#include "ck_tile/host/reference/reference_grouped_conv_bwd_data.hpp"
+
+#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
+#include "ck_tile/dispatcher/example_args.hpp"
+
+using namespace ck_tile::dispatcher;
+using namespace ck_tile::dispatcher::grouped_conv_utils;
+using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
+using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
+
+using InDataType = ck_tile::half_t;
+using WeiDataType = ck_tile::half_t;
+using OutDataType = ck_tile::half_t;
+
+DECL_GROUPED_CONV_KERNEL_SET(
+ bwd_data_kernels,
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2),
+ GroupedConvAlgo()
+ .tile(1, 128, 128)
+ .pipeline("compv3")
+ .scheduler("intrawave")
+ .vector_sizes(4, 8, 8),
+ "gfx950"));
+
+int main(int argc, char* argv[])
+{
+ utils::ExampleArgs args("Example 05: Backward Data Validation",
+ "dX = ConvBwdData(dY, W) with CPU reference");
+ args.add_option("--arch", "gfx950", "GPU architecture");
+ args.add_option("-n", "1", "Batch size");
+ args.add_option("-c", "64", "Input channels");
+ args.add_option("-k", "128", "Output channels");
+ args.add_option("--size", "14", "Spatial size (H=W)");
+
+ if(!args.parse(argc, argv))
+ return 0;
+
+ utils::print_header("Example 05: Backward Data with CPU Validation");
+
+ std::string gfx_arch = args.get("--arch", "gfx950");
+ int N = args.get_int("-n", 1), G = 1;
+ int C = args.get_int("-c", 64), K = args.get_int("-k", 128);
+ int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3;
+
+ // Setup
+ ck_tile::conv::ConvParam conv_param{
+ 2,
+        static_cast<ck_tile::long_index_t>(G),
+        static_cast<ck_tile::long_index_t>(N),
+        static_cast<ck_tile::long_index_t>(K),
+        static_cast<ck_tile::long_index_t>(C),
+        {static_cast<ck_tile::long_index_t>(Y), static_cast<ck_tile::long_index_t>(X)},
+        {static_cast<ck_tile::long_index_t>(Hi), static_cast<ck_tile::long_index_t>(Wi)},
+ {1, 1},
+ {1, 1},
+ {1, 1},
+ {1, 1}};
+
+ using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
+ using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
+ using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
+
+ auto in_desc =
+ ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param);
+ auto wei_desc =
+ ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param);
+ auto out_desc =
+ ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param);
+
+ // dY (gradient from next layer) and W (weight) are inputs; dX is output
+    ck_tile::HostTensor<OutDataType> dy(out_desc);
+    ck_tile::HostTensor<WeiDataType> weight(wei_desc);
+    ck_tile::HostTensor<InDataType> dx_gpu(in_desc);
+    ck_tile::HostTensor<InDataType> dx_cpu(in_desc);
+
+    ck_tile::FillUniformDistribution<OutDataType>{-0.5f, 0.5f}(dy);
+    ck_tile::FillUniformDistribution<WeiDataType>{-0.5f, 0.5f}(weight);
+ dx_cpu.SetZero();
+
+ // CPU reference
+ std::cout << "\nStep 1: CPU Reference (bwd_data)\n";
+    std::vector<ck_tile::long_index_t> strides_v = {1, 1};
+    std::vector<ck_tile::long_index_t> dilations_v = {1, 1};
+    std::vector<ck_tile::long_index_t> left_pads_v = {1, 1};
+    std::vector<ck_tile::long_index_t> right_pads_v = {1, 1};
+
+ ck_tile::reference_grouped_conv_bwd_data<2, InDataType, WeiDataType, OutDataType>(
+ dx_cpu, weight, dy, strides_v, dilations_v, left_pads_v, right_pads_v);
+ std::cout << " CPU complete\n";
+
+ // GPU execution via dispatcher
+ std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n";
+
+ GroupedConvRegistry registry;
+ registry.set_name("bwd_data");
+ REGISTER_GENERATED_KERNELS(registry, gfx_arch);
+
+    GroupedConvDispatcher dispatcher(&registry);
+
+ auto problem =
+ create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardData);
+
+ auto* selected = dispatcher.select_kernel(problem);
+ if(!selected)
+ {
+ std::cerr << " ERROR: No bwd_data kernel found!\n";
+ return 1;
+ }
+ std::cout << " Selected: " << selected->name() << "\n";
+
+ ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem wei_dev(weight.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem dx_dev(dx_gpu.get_element_space_size_in_bytes());
+
+ dy_dev.ToDevice(dy.data());
+ wei_dev.ToDevice(weight.data());
+
+ // dispatcher.run(dY, W, dX, problem) for bwd_data
+ float time_ms = dispatcher.run(dy_dev.GetDeviceBuffer(),
+ wei_dev.GetDeviceBuffer(),
+ dx_dev.GetDeviceBuffer(),
+ problem,
+ nullptr);
+
+ dx_dev.FromDevice(dx_gpu.data());
+
+ double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0;
+ std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
+ std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
+
+ // Validation
+ std::cout << "\nStep 3: Validation (GPU vs CPU)\n";
+
+ size_t num_elements = dx_gpu.get_element_space_size();
+ float max_abs = 0, max_rel = 0;
+ size_t mismatches = 0;
+ constexpr float rtol = 5e-2f, atol = 5e-2f;
+
+ for(size_t i = 0; i < num_elements; ++i)
+ {
+        float gv = static_cast<float>(dx_gpu.data()[i]);
+        float cv = static_cast<float>(dx_cpu.data()[i]);
+ float d = std::abs(gv - cv);
+ float r = d / (std::abs(cv) + 1e-6f);
+ max_abs = std::max(max_abs, d);
+ max_rel = std::max(max_rel, r);
+ if(d > atol + rtol * std::abs(cv))
+ ++mismatches;
+ }
+
+ bool passed = (mismatches == 0);
+ std::cout << " Elements: " << num_elements << "\n";
+ std::cout << " Mismatches: " << mismatches << "\n";
+ std::cout << " Max abs diff: " << std::scientific << max_abs << "\n";
+ std::cout << " Max rel diff: " << std::scientific << max_rel << "\n";
+
+ utils::print_separator();
+ std::cout << " dX = ConvBwdData(dY, W)\n";
+ std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
+ utils::print_separator();
+
+ return passed ? 0 : 1;
+}
diff --git a/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp b/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp
new file mode 100644
index 0000000000..41cb75aecf
--- /dev/null
+++ b/dispatcher/examples/grouped_conv/cpp/06_bwd_weight.cpp
@@ -0,0 +1,188 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+// Example 06: Backward Weight with CPU Reference Validation
+//
+// Computes dW = ConvBwdWeight(X, dY) on GPU via dispatcher.run()
+// and validates against ck_tile::reference_grouped_conv_bwd_weight.
+//
+// Build: cd dispatcher/build && cmake .. && make grouped_conv_06_bwd_weight
+
+#include <cmath>
+#include <iomanip>
+#include <iostream>
+#include <string>
+#include <vector>
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/host.hpp"
+#include "ck_tile/host/convolution_parameter.hpp"
+#include "ck_tile/ops/grouped_convolution.hpp"
+#include "ck_tile/host/reference/reference_grouped_conv_bwd_weight.hpp"
+
+#include "ck_tile/dispatcher/grouped_conv_utils.hpp"
+#include "ck_tile/dispatcher/example_args.hpp"
+
+using namespace ck_tile::dispatcher;
+using namespace ck_tile::dispatcher::grouped_conv_utils;
+using GroupedConvSig = grouped_conv_decl::GroupedConvSignature;
+using GroupedConvAlgo = grouped_conv_decl::GroupedConvAlgorithm;
+
+using InDataType = ck_tile::half_t;
+using WeiDataType = ck_tile::half_t;
+using OutDataType = ck_tile::half_t;
+
+DECL_GROUPED_CONV_KERNEL_SET(
+ bwd_weight_kernels,
+ .add(GroupedConvSig().dtype("fp16").layout("nhwgc").conv_type("bwd_weight").dims(2),
+ GroupedConvAlgo()
+ .tile(1, 128, 128)
+ .pipeline("compv3")
+ .scheduler("intrawave")
+ .memory_op("atomic_add")
+ .vector_sizes(4, 8, 8),
+ "gfx950"));
+
+int main(int argc, char* argv[])
+{
+ utils::ExampleArgs args("Example 06: Backward Weight Validation",
+ "dW = ConvBwdWeight(X, dY) with CPU reference");
+ args.add_option("--arch", "gfx950", "GPU architecture");
+ args.add_option("-n", "1", "Batch size");
+ args.add_option("-c", "64", "Input channels");
+ args.add_option("-k", "128", "Output channels");
+ args.add_option("--size", "14", "Spatial size (H=W)");
+ args.add_option("--split-k", "1", "Split-K factor for bwd_weight (k_batch)");
+
+ if(!args.parse(argc, argv))
+ return 0;
+
+ utils::print_header("Example 06: Backward Weight with CPU Validation");
+
+ std::string gfx_arch = args.get("--arch", "gfx950");
+ int N = args.get_int("-n", 1), G = 1;
+ int C = args.get_int("-c", 64), K = args.get_int("-k", 128);
+ int Hi = args.get_int("--size", 14), Wi = Hi, Y = 3, X = 3;
+
+ // Setup
+ ck_tile::conv::ConvParam conv_param{
+ 2,
+        static_cast<ck_tile::long_index_t>(G),
+        static_cast<ck_tile::long_index_t>(N),
+        static_cast<ck_tile::long_index_t>(K),
+        static_cast<ck_tile::long_index_t>(C),
+        {static_cast<ck_tile::long_index_t>(Y), static_cast<ck_tile::long_index_t>(X)},
+        {static_cast<ck_tile::long_index_t>(Hi), static_cast<ck_tile::long_index_t>(Wi)},
+ {1, 1},
+ {1, 1},
+ {1, 1},
+ {1, 1}};
+
+ using InLayout = ck_tile::tensor_layout::convolution::NHWGC;
+ using WeiLayout = ck_tile::tensor_layout::convolution::GKYXC;
+ using OutLayout = ck_tile::tensor_layout::convolution::NHWGK;
+
+ auto in_desc =
+ ck_tile::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param);
+ auto wei_desc =
+ ck_tile::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param);
+ auto out_desc =
+ ck_tile::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param);
+
+ // X (input) and dY (gradient) are inputs; dW is output
+    ck_tile::HostTensor<InDataType> input(in_desc);
+    ck_tile::HostTensor<OutDataType> dy(out_desc);
+    ck_tile::HostTensor<WeiDataType> dw_gpu(wei_desc);
+    ck_tile::HostTensor<WeiDataType> dw_cpu(wei_desc);
+
+    ck_tile::FillUniformDistribution<InDataType>{-0.5f, 0.5f}(input);
+    ck_tile::FillUniformDistribution<OutDataType>{-0.5f, 0.5f}(dy);
+ dw_cpu.SetZero();
+
+ // CPU reference
+ std::cout << "\nStep 1: CPU Reference (bwd_weight)\n";
+    std::vector<ck_tile::long_index_t> strides_v = {1, 1};
+    std::vector<ck_tile::long_index_t> dilations_v = {1, 1};
+    std::vector<ck_tile::long_index_t> left_pads_v = {1, 1};
+    std::vector<ck_tile::long_index_t> right_pads_v = {1, 1};
+
+ ck_tile::reference_grouped_conv_bwd_weight<2, InDataType, WeiDataType, OutDataType>(
+ input, dw_cpu, dy, strides_v, dilations_v, left_pads_v, right_pads_v);
+ std::cout << " CPU complete\n";
+
+ // GPU execution
+ std::cout << "\nStep 2: GPU Execution (via dispatcher.run)\n";
+
+ GroupedConvRegistry registry;
+ registry.set_name("bwd_weight");
+ REGISTER_GENERATED_KERNELS(registry, gfx_arch);
+
+    GroupedConvDispatcher dispatcher(&registry);
+
+ auto problem =
+ create_grouped_conv2d_problem(N, C, K, Hi, Wi, Y, X, 1, 1, GroupedConvOp::BackwardWeight);
+ problem.split_k = args.get_int("--split-k", 1);
+
+ auto* selected = dispatcher.select_kernel(problem);
+ if(!selected)
+ {
+ std::cerr << " ERROR: No bwd_weight kernel found!\n";
+ return 1;
+ }
+ std::cout << " Selected: " << selected->name() << "\n";
+
+ ck_tile::DeviceMem in_dev(input.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem dy_dev(dy.get_element_space_size_in_bytes());
+ ck_tile::DeviceMem dw_dev(dw_gpu.get_element_space_size_in_bytes());
+
+ in_dev.ToDevice(input.data());
+ dy_dev.ToDevice(dy.data());
+ if(problem.split_k > 1)
+ dw_dev.SetZero();
+
+ // dispatcher.run(X, dY, dW, problem) for bwd_weight
+ float time_ms = dispatcher.run(in_dev.GetDeviceBuffer(),
+ dy_dev.GetDeviceBuffer(),
+ dw_dev.GetDeviceBuffer(),
+ problem,
+ nullptr);
+
+ dw_dev.FromDevice(dw_gpu.data());
+
+ double tflops = (time_ms > 0) ? calculate_conv_tflops(problem, time_ms) : 0;
+ std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
+ std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
+
+ // Validation
+ std::cout << "\nStep 3: Validation (GPU vs CPU)\n";
+
+ size_t num_elements = dw_gpu.get_element_space_size();
+ float max_abs = 0, max_rel = 0;
+ size_t mismatches = 0;
+ constexpr float rtol = 5e-2f, atol = 5e-2f;
+
+ for(size_t i = 0; i < num_elements; ++i)
+ {
+        float gv = static_cast<float>(dw_gpu.data()[i]);
+        float cv = static_cast<float>(dw_cpu.data()[i]);
+ float d = std::abs(gv - cv);
+ float r = d / (std::abs(cv) + 1e-6f);
+ max_abs = std::max(max_abs, d);
+ max_rel = std::max(max_rel, r);
+ if(d > atol + rtol * std::abs(cv))
+ ++mismatches;
+ }
+
+ bool passed = (mismatches == 0);
+ std::cout << " Elements: " << num_elements << "\n";
+ std::cout << " Mismatches: " << mismatches << "\n";
+ std::cout << " Max abs diff: " << std::scientific << max_abs << "\n";
+ std::cout << " Max rel diff: " << std::scientific << max_rel << "\n";
+
+ utils::print_separator();
+ std::cout << " dW = ConvBwdWeight(X, dY)\n";
+ std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";
+ utils::print_separator();
+
+ return passed ? 0 : 1;
+}
diff --git a/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp b/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp
new file mode 100644
index 0000000000..5c95f2c45a
--- /dev/null
+++ b/dispatcher/examples/grouped_conv/cpp/07_multi_tile_benchmark.cpp
@@ -0,0 +1,226 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+// Example 07: Multi-Tile Benchmark
+//
+// Benchmarks multiple tile configurations across ResNet-like problem sizes.
+// Exposes warmup, repeat, and init method as CLI args (matching CK Tile
+// example 20 patterns).
+//
+// Build: cd dispatcher/build && cmake .. && make grouped_conv_07_benchmark
+
+#include <iostream>