Merge branch 'develop' into whole_k_prefetch_n0loop

This commit is contained in:
Qianfeng Zhang
2026-02-24 09:49:40 +00:00
2279 changed files with 228242 additions and 60567 deletions

View File

@@ -1,30 +0,0 @@
resources:
repositories:
- repository: pipelines_repo
type: github
endpoint: ROCm
name: ROCm/ROCm
variables:
- group: common
- template: /.azuredevops/variables-global.yml@pipelines_repo
trigger:
batch: true
branches:
include:
- develop
- amd-develop
paths:
exclude:
- .github
- docs
- '.*.y*ml'
- '*.md'
- Jenkinsfile
- LICENSE
pr: none
jobs:
- template: ${{ variables.CI_COMPONENT_PATH }}/composable_kernel.yml@pipelines_repo

12
.github/CODEOWNERS vendored
View File

@@ -1,8 +1,8 @@
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd
* @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @tenpercent @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron
# Documentation files
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd @ddembeckAMD
docs/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
*.md @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
*.rst @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
.readthedocs.yaml @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @ddembeckAMD @vpietila-amd @Snektron
# Header directory for Doxygen documentation
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @aska-0096 @cgmillette @shumway @vidyasagar-amd
library/include/ @ROCm/rocm-documentation @illsilin @carlushuang @qianfengz @aosewski @poyenc @geyyer @bartekxk @andriy-ca @afagaj @asleepzzz @ThomasNing @coderfeli @cgmillette @shumway @vidyasagar-amd @vpietila-amd @Snektron

View File

@@ -1,143 +0,0 @@
import fnmatch
import json
import os
from pathlib import Path
import subprocess
import sys
from typing import Iterable, Optional, Mapping
def gha_set_output(vars: Mapping[str, str | Path]):
"""Sets values in a step's output parameters.
This appends to the file located at the $GITHUB_OUTPUT environment variable.
See
* https://docs.github.com/en/actions/reference/workflow-commands-for-github-actions#setting-an-output-parameter
* https://docs.github.com/en/actions/writing-workflows/choosing-what-your-workflow-does/passing-information-between-jobs
"""
print(f"Setting github output:\n{vars}")
step_output_file = os.getenv("GITHUB_OUTPUT")
if not step_output_file:
print(" Warning: GITHUB_OUTPUT env var not set, can't set github outputs")
return
with open(step_output_file, "a") as f:
f.writelines(f"{k}={str(v)}" + "\n" for k, v in vars.items())
def get_modified_paths(base_ref: str) -> Optional[Iterable[str]]:
    """Returns the paths of modified files relative to the base reference.

    Args:
        base_ref: Git reference (e.g. ``HEAD^``) to diff the working tree
            against.

    Returns:
        One relative path per modified file, or ``None`` if computing the
        diff timed out.

    Raises:
        subprocess.CalledProcessError: If git exits with a nonzero status
            (``check=True``).
    """
    try:
        return subprocess.run(
            ["git", "diff", "--name-only", base_ref],
            stdout=subprocess.PIPE,
            check=True,
            text=True,
            timeout=60,
        ).stdout.splitlines()
    # BUGFIX: subprocess.run raises subprocess.TimeoutExpired when the
    # timeout elapses, not the builtin TimeoutError the original caught,
    # so the timeout fallback below was unreachable.
    except subprocess.TimeoutExpired:
        print(
            "Computing modified files timed out. Not using PR diff to determine"
            " jobs to run.",
            file=sys.stderr,
        )
        return None
# Filename patterns (relative to the workflow/script directories) that mark
# a file as part of TheRock CI plumbing.
GITHUB_WORKFLOWS_CI_PATTERNS = [
    "therock*",
]


def is_path_workflow_file_related_to_ci(path: str) -> bool:
    """Returns True if `path` is a workflow or CI script file matching a CI pattern."""
    # Check .github/workflows/ first, then .github/scripts/, mirroring the
    # short-circuit order of two chained any() calls.
    for prefix in (".github/workflows/", ".github/scripts/"):
        for pattern in GITHUB_WORKFLOWS_CI_PATTERNS:
            if fnmatch.fnmatch(path, prefix + pattern):
                return True
    return False
def check_for_workflow_file_related_to_ci(paths: Optional[Iterable[str]]) -> bool:
    """Returns True if any of `paths` is a CI-related workflow/script file."""
    if paths is None:
        # No diff available (e.g. timed out): treat as "no CI files touched".
        return False
    for candidate in paths:
        if is_path_workflow_file_related_to_ci(candidate):
            return True
    return False
# Paths matching any of these patterns are considered to have no influence over
# build or test workflows so any related jobs can be skipped if all paths
# modified by a commit/PR match a pattern in this list.
SKIPPABLE_PATH_PATTERNS = [
    "docs/*",
    "*.gitignore",
    "*.md",
    "*.pre-commit-config.*",
    "*LICENSE",
    "Jenkinsfile",
    ".github/ISSUE_TEMPLATE/*",
    ".github/CODEOWNERS",
    ".github/*.md",
    ".github/dependabot.yml",
]


def is_path_skippable(path: str) -> bool:
    """Determines if a given relative path to a file matches any skippable patterns."""
    for pattern in SKIPPABLE_PATH_PATTERNS:
        if fnmatch.fnmatch(path, pattern):
            return True
    return False
def check_for_non_skippable_path(paths: Optional[Iterable[str]]) -> bool:
    """Returns true if at least one path is not in the skippable set."""
    if paths is None:
        return False
    # "some path is not skippable" == "not every path is skippable".
    return not all(is_path_skippable(p) for p in paths)
def should_ci_run_given_modified_paths(paths: Optional[Iterable[str]]) -> bool:
    """Returns true if CI workflows should run given a list of modified paths.

    CI runs when either a TheRock-related workflow file was touched, or at
    least one non-workflow path is not in the skippable set. A ``None``
    input (diff unavailable) disables CI.
    """
    if paths is None:
        print("No files were modified, skipping TheRock CI jobs")
        return False

    # BUGFIX: materialize once and derive both subsets from the set. The
    # original iterated the `paths` iterable a second time to build
    # github_workflows_paths, which yields nothing if `paths` is a
    # one-shot iterator (it is typed as Iterable, not a list).
    paths_set = set(paths)
    github_workflows_paths = {
        p for p in paths_set if p.startswith(".github/workflows")
    }
    other_paths = paths_set - github_workflows_paths

    related_to_ci = check_for_workflow_file_related_to_ci(github_workflows_paths)
    contains_other_non_skippable_files = check_for_non_skippable_path(other_paths)
    print("should_ci_run_given_modified_paths findings:")
    # Log both findings (the original omitted related_to_ci, making the
    # "findings" header misleading when debugging CI gating).
    print(f"  related_to_ci: {related_to_ci}")
    print(f"  contains_other_non_skippable_files: {contains_other_non_skippable_files}")

    if related_to_ci:
        print("Enabling build jobs since a related workflow file was modified")
        return True
    elif contains_other_non_skippable_files:
        print("Enabling TheRock CI jobs since a non-skippable path was modified")
        return True
    else:
        print(
            "Only unrelated and/or skippable paths were modified, skipping TheRock CI jobs"
        )
        return False
def main(args):
    """Entry point: diff against `args["base_ref"]` and publish the CI-enable flag.

    Args:
        args: Mapping with key ``"base_ref"`` naming the git reference to
            diff against.
    """
    base_ref = args.get("base_ref")
    modified_paths = get_modified_paths(base_ref)
    # BUGFIX: modified_paths is None when the git diff timed out; slicing
    # None raised a TypeError before the None-tolerant
    # should_ci_run_given_modified_paths was ever reached.
    if modified_paths is not None:
        print("modified_paths (max 200):", modified_paths[:200])
    enable_jobs = should_ci_run_given_modified_paths(modified_paths)
    # json.dumps renders lowercase "true"/"false", matching the workflow's
    # string comparison `enable_therock_ci == 'true'`.
    output = {"enable_therock_ci": json.dumps(enable_jobs)}
    gha_set_output(output)
if __name__ == "__main__":
    # BASE_REF defaults to the first parent of HEAD (the base branch tip
    # for a PR merge commit).
    main({"base_ref": os.environ.get("BASE_REF", "HEAD^1")})

View File

@@ -1,16 +0,0 @@
name: pre-commit
on:
pull_request:
push:
branches: [develop]
jobs:
pre-commit:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.12'
- uses: pre-commit/action@v3.0.1

View File

@@ -1,145 +0,0 @@
name: TheRock CI Linux
on:
workflow_call:
inputs:
cmake_options:
type: string
amdgpu_families:
type: string
test_runs_on:
type: string
permissions:
contents: read
jobs:
therock-build-linux:
name: Build Linux Packages
runs-on: azure-linux-scale-rocm
permissions:
id-token: write
container:
image: ghcr.io/rocm/therock_build_manylinux_x86_64@sha256:2f3ebd0beb04c449fdb36933e54bdc69483b914fb9005594d3fc9444c206b54b
options: -v /runner/config:/home/awsconfig/
env:
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
TEATIME_FORCE_INTERACTIVE: 0
AWS_SHARED_CREDENTIALS_FILE: /home/awsconfig/credentials.ini
CACHE_DIR: ${{ github.workspace }}/.container-cache
# The ccache.conf will be written by setup_ccache.py before this gets used.
CCACHE_CONFIGPATH: ${{ github.workspace }}/.ccache/ccache.conf
steps:
- name: "Checking out repository for rocm-libraries"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/rocm-libraries"
- name: Pull DVC files for rocm-libraries # LOGNAME details here https://github.com/ROCm/rocm-libraries/pull/1617
run: |
if command -v dvc &> /dev/null; then
echo "dvc detected"
else
echo "Warning, dvc not detected!"
fi
LOGNAME=github-runner dvc pull -v
- name: Checkout composable_kernel repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
path: "composable_kernel"
- name: Checkout TheRock repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
path: "TheRock"
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
- name: Setup ccache
run: |
./TheRock/build_tools/setup_ccache.py \
--config-preset "github-oss-presubmit" \
--dir "$(dirname $CCACHE_CONFIGPATH)" \
--local-path "$CACHE_DIR/ccache"
echo "namespace = ext_composable_kernel" >> $CCACHE_CONFIGPATH
echo "[*] ccache_config contents:"
cat $CCACHE_CONFIGPATH
- name: Runner Health Settings
run: |
./TheRock/build_tools/health_status.py
- name: Fetch sources
run: |
./TheRock/build_tools/fetch_sources.py --jobs 12 --no-include-rocm-libraries --no-include-ml-frameworks
- name: Patch rocm-libraries
run: |
git config --global --add safe.directory '*'
# Remove patches here if they cannot be applied cleanly, and they have not been deleted from TheRock repo
rm -f ./TheRock/patches/amd-mainline/rocm-libraries/0008-Revert-remove-options-no-enumerate-966.patch
git -c user.name="therockbot" -c "user.email=therockbot@amd.com" am --whitespace=nowarn ./TheRock/patches/amd-mainline/rocm-libraries/*.patch
- name: Install python deps
run: |
pip install -r TheRock/requirements.txt
pip freeze
- name: Configure Projects
env:
amdgpu_families: ${{ env.AMDGPU_FAMILIES }}
package_version: ADHOCBUILD
extra_cmake_options: ${{ inputs.cmake_options }}
BUILD_DIR: build
run: |
python3 TheRock/build_tools/github_actions/build_configure.py
- name: Build TheRock
run: cmake --build TheRock/build
- name: Build therock-archives
run: cmake --build TheRock/build --target therock-archives
- name: Report
if: ${{ !cancelled() }}
run: |
echo "Full SDK du:"
echo "------------"
du -h -d 1 TheRock/build/dist/rocm
echo "Artifact Archives:"
echo "------------------"
ls -lh TheRock/build/artifacts/*.tar.xz
echo "Artifacts:"
echo "----------"
du -h -d 1 TheRock/build/artifacts
echo "CCache Stats:"
echo "-------------"
ccache -s -v
tail -v -n +1 .ccache/compiler_check_cache/* > TheRock/build/logs/ccache_compiler_check_cache.log
- name: Configure AWS Credentials for non-forked repos
if: ${{ always() && !github.event.pull_request.head.repo.fork }}
uses: aws-actions/configure-aws-credentials@7474bc4690e29a8392af63c5b98e7449536d5c3a # v4.3.1
with:
aws-region: us-east-2
role-to-assume: arn:aws:iam::692859939525:role/therock-artifacts-external
- name: Post Build Upload
if: always()
run: |
python3 TheRock/build_tools/github_actions/post_build_upload.py \
--run-id ${{ github.run_id }} \
--artifact-group ${{ env.AMDGPU_FAMILIES }} \
--build-dir TheRock/build \
--upload
therock-test-linux:
name: "Test"
needs: [therock-build-linux]
uses: ./.github/workflows/therock-test-packages.yml
with:
project_to_test: "miopen"
amdgpu_families: ${{ inputs.amdgpu_families }}
test_runs_on: ${{ inputs.test_runs_on }}
platform: "linux"

View File

@@ -1,88 +0,0 @@
name: TheRock CI for composable_kernel
on:
push:
branches:
- develop
workflow_dispatch:
pull_request:
types:
- opened
- synchronize
branches:
- mainline
- release/*
- release-staging/*
- develop
permissions:
contents: read
concurrency:
# A PR number if a pull request and otherwise the commit hash. This cancels
# queued and in-progress runs for the same PR (presubmit) or commit
# (postsubmit). The workflow name is prepended to avoid conflicts between
# different workflows.
group: ${{ github.workflow }}-${{ github.event.number || github.sha }}
cancel-in-progress: true
jobs:
setup:
runs-on: ubuntu-24.04
env:
# The commit being checked out is the merge commit for a PR. Its first
# parent will be the tip of the base branch.
BASE_REF: HEAD^
outputs:
enable_therock_ci: ${{ steps.configure.outputs.enable_therock_ci }}
steps:
- name: "Checking out repository"
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
# We need the parent commit to do a diff
fetch-depth: 2
- name: "Configuring CI options"
id: configure
run: python .github/scripts/therock_configure_ci.py
therock-ci-linux:
name: TheRock CI Linux
needs: setup
if: ${{ needs.setup.outputs.enable_therock_ci == 'true' }}
permissions:
contents: read
id-token: write
uses: ./.github/workflows/therock-ci-linux.yml
secrets: inherit
with:
cmake_options: >-
-DTHEROCK_ENABLE_COMPOSABLE_KERNEL=ON
-DTHEROCK_ENABLE_MIOPEN=ON
-DTHEROCK_ENABLE_ALL=OFF
-DTHEROCK_USE_EXTERNAL_COMPOSABLE_KERNEL=ON
-DTHEROCK_COMPOSABLE_KERNEL_SOURCE_DIR=../composable_kernel
-DTHEROCK_USE_EXTERNAL_ROCM_LIBRARIES=ON
-DTHEROCK_ROCM_LIBRARIES_SOURCE_DIR=../
amdgpu_families: "gfx94X-dcgpu"
test_runs_on: "linux-mi325-1gpu-ossci-rocm"
therock_ci_summary:
name: TheRock CI Summary
if: always()
needs:
- setup
- therock-ci-linux
runs-on: ubuntu-24.04
steps:
- name: Output failed jobs
run: |
echo '${{ toJson(needs) }}'
FAILED_JOBS="$(echo '${{ toJson(needs) }}' \
| jq --raw-output \
'map_values(select(.result!="success" and .result!="skipped")) | keys | join(",")' \
)"
if [[ "${FAILED_JOBS}" != "" ]]; then
echo "The following jobs failed: ${FAILED_JOBS}"
exit 1
fi

View File

@@ -1,72 +0,0 @@
name: Test component
on:
workflow_call:
inputs:
artifact_run_id:
type: string
default: ""
amdgpu_families:
type: string
test_runs_on:
type: string
platform:
type: string
component:
type: string
permissions:
contents: read
jobs:
test_component:
name: 'Test ${{ fromJSON(inputs.component).job_name }} (shard ${{ matrix.shard }} of ${{ fromJSON(inputs.component).total_shards }})'
runs-on: ${{ inputs.test_runs_on }}
container:
image: ${{ inputs.platform == 'linux' && 'ghcr.io/rocm/no_rocm_image_ubuntu24_04@sha256:4150afe4759d14822f0e3f8930e1124f26e11f68b5c7b91ec9a02b20b1ebbb98' || null }}
options: --ipc host
--group-add video
--device /dev/kfd
--device /dev/dri
--group-add 110
--env-file /etc/podinfo/gha-gpu-isolation-settings
strategy:
fail-fast: false
matrix:
# The shard array is based on "total_shards" from "fetch_test_configurations.py"
# The test executable will shard based on the array. (ex: [1, 2, 3, 4] = four test shards)
shard: ${{ fromJSON(inputs.component).shard_arr }}
defaults:
run:
shell: bash
env:
VENV_DIR: ${{ github.workspace }}/.venv
ARTIFACT_RUN_ID: "${{ inputs.artifact_run_id != '' && inputs.artifact_run_id || github.run_id }}"
OUTPUT_ARTIFACTS_DIR: "./build"
THEROCK_BIN_DIR: "./build/bin"
AMDGPU_FAMILIES: ${{ inputs.amdgpu_families }}
steps:
- name: Checkout Repository
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
with:
repository: "ROCm/TheRock"
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
- name: Run setup test environment workflow
uses: './.github/actions/setup_test_environment'
with:
ARTIFACT_RUN_ID: ${{ env.ARTIFACT_RUN_ID }}
ARTIFACT_GROUP: ${{ inputs.amdgpu_families }}
OUTPUT_ARTIFACTS_DIR: ${{ env.OUTPUT_ARTIFACTS_DIR }}
VENV_DIR: ${{ env.VENV_DIR }}
FETCH_ARTIFACT_ARGS: ${{ fromJSON(inputs.component).fetch_artifact_args }}
IS_PR_FROM_FORK: ${{ github.event.pull_request.head.repo.fork }}
- name: Test
timeout-minutes: ${{ fromJSON(inputs.component).timeout_minutes }}
env:
SHARD_INDEX: ${{ matrix.shard }}
TOTAL_SHARDS: ${{ fromJSON(inputs.component).total_shards }}
run: |
${{ fromJSON(inputs.component).test_script }}

View File

@@ -1,54 +0,0 @@
name: TheRock Test Packages
on:
workflow_call:
inputs:
project_to_test:
type: string
amdgpu_families:
type: string
test_runs_on:
type: string
platform:
type: string
permissions:
contents: read
jobs:
configure_test_matrix:
name: "Configure test matrix"
runs-on: ubuntu-24.04
if: ${{ inputs.test_runs_on != '' }}
outputs:
components: ${{ steps.configure.outputs.components }}
steps:
- name: "Checking out repository"
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: "ROCm/TheRock"
ref: f3f77a3161922df3eee006b888b439d75b2b4668 # 2025-10-29 commit
- name: "Configuring CI options"
env:
PLATFORM: ${{ inputs.platform }}
project_to_test: ${{ inputs.project_to_test }}
id: configure
run: python ./build_tools/github_actions/fetch_test_configurations.py
test_components:
name: 'Test ${{ matrix.components.job_name }}'
needs: [configure_test_matrix]
# skip tests if no test matrix to run
if: ${{ needs.configure_test_matrix.outputs.components != '[]' }}
strategy:
fail-fast: false
matrix:
components: ${{ fromJSON(needs.configure_test_matrix.outputs.components) }}
uses: './.github/workflows/therock-test-component.yml'
with:
artifact_run_id: ${{ github.run_id }}
amdgpu_families: ${{ inputs.amdgpu_families }}
test_runs_on: ${{ inputs.test_runs_on }}
platform: ${{ inputs.platform }}
component: ${{ toJSON(matrix.components) }}

29
.gitignore vendored
View File

@@ -36,6 +36,9 @@ tags
# Editors
.vscode
# CMake formatting configuration (local)
.cmake-format.yaml
# Cline
.cline*
@@ -78,9 +81,35 @@ CMakeUserPresets.json
# Python cache
__pycache__/
# Cache directories
.cache/
.ck_tile_cache/
ck_tile_cache/
**/kernel_cache/
**/.kernel_cache/
# Dispatcher kernel cache (user-generated, can be large)
dispatcher/**/kernel_cache/
dispatcher/**/.kernel_cache/
dispatcher/**/cached_kernels/
dispatcher/**/*.hsaco
dispatcher/**/*.co
# Dispatcher generated JSON exports
dispatcher/**/*_kernels.json
dispatcher/**/dispatcher_kernels.json
# Generated test data
test_data/*
!test_data/*.py
!test_data/*.sh
!test_data/requirements.txt
# Exceptions to build* patterns above
# The experimental/builder directory should be tracked despite matching build*
!experimental/builder
!experimental/builder/**
experimental/grouped_convolution_tile_instances/instances/*
!experimental/grouped_convolution_tile_instances/instances/*.in
!experimental/grouped_convolution_tile_instances/instances/*.inc
experimental/grouped_convolution_tile_instances/*.inc

View File

@@ -20,21 +20,21 @@ repos:
)$
- repo: local
hooks:
# - id: copyright-year-checker
# name: copyright-year-checker
# entry: script/check_copyright_year.sh
# verbose: false
# language: script
# types: [c++]
- id: copyright-header-checker
name: Check copyright headers
entry: projects/composablekernel/script/check_copyright_year.sh
verbose: false
language: script
types_or: [c++, python, shell, cmake]
- id: remove-exec-bit
name: Remove executable bit from non-executable files
entry: script/remove_exec_bit.sh
entry: projects/composablekernel/script/remove_exec_bit.sh
language: script
types_or: [c++, text]
verbose: true
- id: remod-ck-tile
name: Run ck_tile remod.py
entry: python script/remod_for_ck_tile.py
entry: python projects/composablekernel/script/remod_for_ck_tile.py
language: python
files: '^(include|example)/ck_tile/.*$'
additional_dependencies:

View File

@@ -2,32 +2,58 @@
Documentation for Composable Kernel available at [https://rocm.docs.amd.com/projects/composable_kernel/en/latest/](https://rocm.docs.amd.com/projects/composable_kernel/en/latest/).
## (Unreleased) Composable Kernel 1.3.0
### Added
* Added preshuffleB support for abquant mode in blockscale GEMM.
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It could be enabled/disabled via `DTYPES` of "tf32".
* Added streamingllm sink support for FMHA FWD, include qr_ks_vs, qr_async and splitkv pipelines.
* Added support for microscaling (MX) FP8/FP4 mixed data types to Flatmm pipeline.
* Added support for fp8 dynamic tensor-wise quantization of fp8 fmha fwd kernel.
* Added FP8 KV cache support for FMHA batch prefill.
* Added support for gfx1153 target.
* Added FMHA batch prefill kernel support for several KV cache layouts, flexible page sizes, and different lookup table configurations.
* Added gpt-oss sink support for FMHA FWD, including the qr_ks_vs, qr_async, qr_async_trload, and splitkv pipelines.
* Added persistent async input scheduler for CK Tile universal GEMM kernels to support asynchronous input streaming.
* Added FP8 block scale quantization for FMHA forward kernel.
* Added gfx11 support for FMHA.
### Changed
### Upcoming changes
## Composable Kernel 1.2.0 for ROCm 7.2.0
### Added
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with C++ and Python frontends starting with GEMM support.
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM
* Added a compute async pipeline in the CK TILE universal GEMM on gfx950
* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM.
* Added Col-Col-Row-Col layout support for aquant mode in blockscale GEMM.
* Added support for mixed precision fp8 x bf8 universal GEMM and weight preshuffle GEMM.
* Added a compute async pipeline in the CK Tile universal GEMM on gfx950.
* Added support for B Tensor type `pk_int4_t` in the CK Tile weight preshuffle GEMM.
* Added the new api to load different memory sizes to SGPR.
* Added support for B Tensor Preshuffle in CK TILE Grouped GEMM.
* Added support for B Tensor preshuffle in CK Tile grouped GEMM.
* Added a basic copy kernel example and supporting documentation for new CK Tile developers.
* Added support for grouped_gemm kernels to perform multi_d elementwise operation.
* Added support for Multiple ABD GEMM
* Added support for grouped GEMM kernels to perform Multi D elementwise operation.
* Added support for multiple ABD GEMM.
* Added benchmarking support for tile engine GEMM Multi D.
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
* Added support for f32 to FMHA (fwd/bwd).
* Added tensor-wise quantization for CK_TILE GEMM.
* Added block scaling support in CK Tile GEMM, allowing flexible use of quantization matrices from either A or B operands.
* Added the row-wise column-wise quantization for CK Tile GEMM and CK Tile grouped GEMM.
* Added support for f32 to FMHA (forward and backward).
* Added tensor-wise quantization for CK Tile GEMM.
* Added support for batched contraction kernel.
* Added WMMA (gfx12) support for FMHA.
* Added pooling kernel in CK_TILE
* Added top-k sigmoid kernel in CK_TILE
* Added the blockscale 2D support for CK_TILE GEMM.
* Added Flatmm pipeline for microscaling (MX) FP8/FP4 data types
* Added reduce and multi reduction kernels
### Changed
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK_TILE (#2594)
* Removed `BlockSize` in `make_kernel` and `CShuffleEpilogueProblem` to support Wave32 in CK Tile (#2594)
* Added an optional template parameter `Arch` (`gfx9_t`, `gfx12_t` etc.) to `make_kernel` to support linking multiple object files that have the same kernel compiled for different architectures.
* FMHA examples and tests can be built for multiple architectures (gfx9, gfx950, gfx12) at the same time.
@@ -73,11 +99,12 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
* Added rotating buffer feature for CK_Tile GEMM.
* Added int8 support for CK_TILE GEMM.
* Added CK Tile Epilogue Chainer framework for composable epilogue sequences in GEMM operations
### Optimized
* Optimize the gemm multiply multiply preshuffle & lds bypass with Pack of KGroup and better instruction layout.
* Added Vectorize Transpose optimization for CK Tile
* Added Vectorize Transpose optimization for CK Tile
* Added the asynchronous copy for gfx950
### Changed

View File

@@ -1,4 +1,7 @@
cmake_minimum_required(VERSION 3.14)
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
cmake_minimum_required(VERSION 3.21)
if(POLICY CMP0140)
# policies CMP0140 not known to CMake until 3.25
cmake_policy(SET CMP0140 NEW)
@@ -28,21 +31,31 @@ endif()
# Default installation path
if(NOT WIN32)
set(CMAKE_INSTALL_PREFIX "/opt/rocm" CACHE PATH "")
else()
set(CMAKE_INSTALL_PREFIX "C:/dist/TheRock" CACHE PATH "")
endif()
# Enable ASAN when THEROCK_SANITIZER is set to ASAN or HOST_ASAN
if(THEROCK_SANITIZER STREQUAL "ASAN" OR THEROCK_SANITIZER STREQUAL "HOST_ASAN")
set(ENABLE_ASAN_PACKAGING ON)
message(STATUS "Enabling ASAN for Composable Kernel (THEROCK_SANITIZER=${THEROCK_SANITIZER})")
endif()
set(version 1.2.0)
# Check support for CUDA/HIP in Cmake
project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
project(composable_kernel VERSION ${version} LANGUAGES CXX)
include(CTest)
option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON)
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF)
option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF)
if(CK_EXPERIMENTAL_BUILDER)
add_definitions(-DCK_EXPERIMENTAL_BUILDER)
include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include)
include_directories(${PROJECT_SOURCE_DIR}/experimental/builder/include)
endif()
# Usage: for customized Python location cmake -DCK_USE_ALTERNATIVE_PYTHON="/opt/Python-3.8.13/bin/python3.8"
@@ -87,6 +100,10 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "tf32")
# definition will be added based on the GPU target in the following section
set(CK_ENABLE_TF32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
@@ -101,6 +118,7 @@ else()
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_TF32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
set(CK_ENABLE_FP8 "ON")
@@ -113,6 +131,9 @@ add_compile_options(-Wno-pass-failed)
add_compile_options(-Wno-switch-default)
add_compile_options(-Wno-unique-object-duplication)
# Increase the number of max elements in fold expressions
add_compile_options(-fbracket-depth=1024)
# add -Og -gdwarf64 for debug builds
add_compile_options(
"$<$<CONFIG:Debug>:-Og>"
@@ -152,7 +173,13 @@ execute_process(COMMAND "${GIT_EXECUTABLE}" rev-parse HEAD OUTPUT_VARIABLE COMMI
configure_file(include/ck/version.h.in ${CMAKE_CURRENT_BINARY_DIR}/include/ck/version.h)
set(ROCM_SYMLINK_LIBS OFF)
find_package(ROCM REQUIRED PATHS /opt/rocm)
if (WIN32)
find_package(ROCmCMakeBuildTools REQUIRED PATHS C:/dist/TheRock)
set(HIP_PLATFORM "amd" CACHE STRING "HIP platform")
else()
find_package(ROCM REQUIRED PATHS /opt/rocm)
endif()
include(ROCMInstallTargets)
include(ROCMPackageConfigHelpers)
@@ -179,7 +206,10 @@ if(GPU_TARGETS)
else()
set(USER_GPU_TARGETS 0)
endif()
find_package(hip REQUIRED)
enable_language(HIP)
# No assumption that HIP kernels are launched with uniform block size for backward compatibility
# SWDEV-413293 and https://reviews.llvm.org/D155213
math(EXPR hip_VERSION_FLAT "(${hip_VERSION_MAJOR} * 1000 + ${hip_VERSION_MINOR}) * 100000 + ${hip_VERSION_PATCH}")
@@ -204,7 +234,7 @@ if(NOT ENABLE_ASAN_PACKAGING)
endif()
else()
#build CK only for xnack-supported targets when using ASAN
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+")
set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+;gfx950:xnack+")
endif()
#if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list
@@ -229,16 +259,21 @@ message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}"
# Cache SUPPORTED_GPU_TARGETS for debug
set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12" AND NOT FORCE_DISABLE_XDL)
message(STATUS "Enabling XDL instances")
add_definitions(-DCK_USE_XDL)
set(CK_USE_XDL "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95")
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND NOT FORCE_DISABLE_XDL)
message(STATUS "Enabling XDL FP8 gemms on native architectures")
add_definitions(-DCK_USE_GFX94)
set(CK_USE_GFX94 "ON")
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx950" AND NOT FORCE_DISABLE_XDL)
message(STATUS "Enabling XDL FP8 gemms on gfx950")
add_definitions(-DCK_USE_GFX950)
set(CK_USE_GFX950 "ON")
endif()
# new macro CK_TILE_USE_WMMA in order to separately compile examples for MFMA/WMMA
set(CK_TILE_USE_WMMA 0)
@@ -247,7 +282,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10")
add_definitions(-DCK_GFX1030_SUPPORT)
endif()
if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12")
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA)
message(STATUS "Enabling WMMA instances")
add_definitions(-DCK_USE_WMMA)
set(CK_USE_WMMA "ON")
@@ -257,7 +292,7 @@ endif()
# define the macro with the current value (0 or 1)
add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA})
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12")
if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA)
message(STATUS "Enabling WMMA FP8 gemms on native architectures")
add_definitions(-DCK_USE_WMMA_FP8)
set(CK_USE_WMMA_FP8 "ON")
@@ -277,6 +312,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
set(CK_GFX950_SUPPORT "ON")
endif()
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "ON")
else()
message(STATUS "Disabling TF32 instances")
remove_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "OFF")
endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
@@ -614,7 +658,7 @@ if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERS
add_compile_options(-fdiagnostics-color=always)
endif()
if(NOT MIOPEN_REQ_LIBS_ONLY)
if(NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
# make check runs the entire set of examples and tests
add_custom_target(check COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} USES_TERMINAL)
# make smoke runs the tests and examples that runs within 30 seconds on gfx90a
@@ -625,7 +669,9 @@ endif()
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/*/device_*_instance.cpp")
# Optimization: Search only in library/src where all instance files actually live
# (was searching entire source tree, taking ~40s instead of <1s)
file(GLOB_RECURSE INSTANCE_FILES "${PROJECT_SOURCE_DIR}/library/src/*/device_*_instance.cpp")
file(GLOB dir_list RELATIVE ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu ${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu/*)
set(CK_DEVICE_INSTANCES)
FOREACH(subdir_path ${dir_list})
@@ -646,6 +692,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
set(add_inst 1)
endif()
@@ -667,12 +716,18 @@ ENDFOREACH()
add_custom_target(instances DEPENDS utility;${CK_DEVICE_INSTANCES} SOURCES ${INSTANCE_FILES})
option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF)
option(HIPTENSOR_REQ_LIBS_ONLY "Build only the HipTensor required libraries" OFF)
option(DISABLE_OFFLOAD_COMPRESS "Disable offload compress compiler flag when building instances" OFF)
option(BUILD_MHA_LIB "Build the static library for flash attention" OFF)
add_subdirectory(library)
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY)
if (CK_EXPERIMENTAL_BUILDER)
add_subdirectory(experimental/builder)
add_subdirectory(experimental/grouped_convolution_tile_instances)
endif()
if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
rocm_package_setup_component(tests
LIBRARY_NAME composablekernel
PACKAGE_NAME tests # Prevent -static suffix on package name
@@ -695,7 +750,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS AND NOT MIOPEN_REQ_LIBS_ONLY)
endif()
endif()
if (NOT MIOPEN_REQ_LIBS_ONLY)
if (NOT MIOPEN_REQ_LIBS_ONLY AND NOT HIPTENSOR_REQ_LIBS_ONLY)
rocm_package_setup_component(profiler
LIBRARY_NAME composablekernel
PACKAGE_NAME ckprofiler
@@ -703,10 +758,6 @@ if (NOT MIOPEN_REQ_LIBS_ONLY)
add_subdirectory(profiler)
endif()
if (CK_EXPERIMENTAL_BUILDER)
add_subdirectory(experimental/builder)
endif()
if(CK_USE_CODEGEN AND (SUPPORTED_GPU_TARGETS MATCHES "gfx9" OR GPU_ARCHS))
add_subdirectory(codegen)
endif()
@@ -739,6 +790,16 @@ rocm_install(FILES
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck/
)
if(CK_EXPERIMENTAL_BUILDER)
rocm_install(DIRECTORY
${PROJECT_SOURCE_DIR}/experimental/builder/include/ck_tile/builder
DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile
)
set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/)
rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile)
endif()
set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE")
set(CPACK_RPM_PACKAGE_LICENSE "MIT")

91
CMakePresets.json Normal file
View File

@@ -0,0 +1,91 @@
{
"version": 3,
"cmakeMinimumRequired": {
"major": 3,
"minor": 21,
"patch": 0
},
"configurePresets": [
{
"name": "use-gfx908",
"hidden": true,
"cacheVariables": {
"GPU_TARGETS": "gfx908"
}
},
{
"name": "use-gfx90a",
"hidden": true,
"cacheVariables": {
"GPU_TARGETS": "gfx90a"
}
},
{
"name": "use-gfx942",
"hidden": true,
"cacheVariables": {
"GPU_TARGETS": "gfx942"
}
},
{
"name": "use-gfx950",
"hidden": true,
"cacheVariables": {
"GPU_TARGETS": "gfx950"
}
},
{
"name": "dev",
"binaryDir": "${sourceDir}/build",
"displayName": "CK Dev",
"environment": {},
"cacheVariables": {
"CMAKE_PREFIX_PATH": "/opt/rocm/",
"CMAKE_CXX_COMPILER": "/opt/rocm/llvm/bin/clang++",
"CMAKE_HIP_COMPILER": "/opt/rocm/llvm/bin/clang++",
"CMAKE_CXX_FLAGS": "-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -fbracket-depth=1024",
"CMAKE_BUILD_TYPE": "Release",
"BUILD_DEV": "ON",
"CMAKE_VERBOSE_MAKEFILE": "ON",
"USE_BITINT_EXTENSION_INT4": "OFF",
"GPU_TARGETS": "gfx908;gfx90a;gfx942"
}
},
{
"name": "dev-gfx908",
"displayName": "CK Dev - gfx908",
"description": "Development build for AMD GPU gfx908",
"inherits": [
"use-gfx908",
"dev"
]
},
{
"name": "dev-gfx90a",
"displayName": "CK Dev - gfx90a",
"description": "Development build for AMD GPU gfx90a",
"inherits": [
"use-gfx90a",
"dev"
]
},
{
"name": "dev-gfx942",
"displayName": "CK Dev - gfx942",
"description": "Development build for AMD GPU gfx942",
"inherits": [
"use-gfx942",
"dev"
]
},
{
"name": "dev-gfx950",
"displayName": "CK Dev - gfx950",
"description": "Development build for AMD GPU gfx950",
"inherits": [
"use-gfx950",
"dev"
]
}
]
}

View File

@@ -1,7 +1,7 @@
FROM ubuntu:24.04
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=7.0.1
ARG ROCMVERSION=7.1.1
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
@@ -13,8 +13,8 @@ ENV DEBIAN_FRONTEND=noninteractive
RUN set -xe && \
apt-get update && apt-get install -y --allow-unauthenticated apt-utils wget gnupg2 curl
RUN wget https://repo.radeon.com/amdgpu-install/7.0.1/ubuntu/noble/amdgpu-install_7.0.1.70001-1_all.deb && \
apt install ./amdgpu-install_7.0.1.70001-1_all.deb -y && \
RUN wget https://repo.radeon.com/amdgpu-install/7.1.1/ubuntu/noble/amdgpu-install_7.1.1.70101-1_all.deb && \
apt install ./amdgpu-install_7.1.1.70101-1_all.deb -y && \
apt update && \
apt install python3-setuptools python3-wheel -y && \
apt install rocm-dev -y

View File

@@ -2,7 +2,7 @@ ARG BASE_DOCKER="rocm/pytorch:latest"
FROM $BASE_DOCKER
ARG AITER_BRANCH="main"
ARG CK_AITER_BRANCH="develop"
RUN pip install pandas zmq einops ninja && \
RUN pip install pandas zmq einops ninja tabulate && \
pip install numpy==1.26.2 && \
sudo mkdir /home/jenkins && \
sudo mkdir /home/jenkins/workspace && \

View File

@@ -1,4 +1,4 @@
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.0.1"
ARG BASE_DOCKER="rocm/composable_kernel:ck_ub24.04_rocm7.1.1"
FROM $BASE_DOCKER
ARG compiler_version=""
ARG compiler_commit=""

101
Dockerfile.manylinux Normal file
View File

@@ -0,0 +1,101 @@
FROM ghcr.io/rocm/therock_build_manylinux_x86_64:latest
ARG DEBIAN_FRONTEND=noninteractive
ARG ROCMVERSION=7.2
ARG compiler_version=""
ARG compiler_commit=""
ARG CK_SCCACHE=""
ARG DEB_ROCM_REPO=http://repo.radeon.com/rocm/apt/.apt_$ROCMVERSION/
ENV APT_KEY_DONT_WARN_ON_DANGEROUS_USAGE=DontWarn
ENV DEBIAN_FRONTEND=noninteractive
USER root
# Add rocm repository
RUN dnf clean all && dnf update -y && dnf -v install wget gnupg2 curl -y
RUN wget https://repo.radeon.com/amdgpu-install/7.2/rhel/8.10/amdgpu-install-7.2.70200-1.el8.noarch.rpm && \
dnf install ./amdgpu-install-7.2.70200-1.el8.noarch.rpm -y && \
dnf update -y && \
dnf install python3-setuptools python3-wheel -y && \
dnf install rocm-dev -y
## Sccache binary built from source for ROCm, only install if CK_SCCACHE is defined
ARG SCCACHE_REPO_URL=http://compute-artifactory.amd.com/artifactory/rocm-generic-experimental/rocm-sccache
ENV SCCACHE_INSTALL_LOCATION=/usr/local/.cargo/bin
ENV PATH=$PATH:${SCCACHE_INSTALL_LOCATION}
ENV CK_SCCACHE=$CK_SCCACHE
RUN if [ "$CK_SCCACHE" != "" ]; then \
mkdir -p ${SCCACHE_INSTALL_LOCATION} && \
curl ${SCCACHE_REPO_URL}/portable/0.2.16/sccache-0.2.16-alpha.1-rocm --output ${SCCACHE_INSTALL_LOCATION}/sccache && \
chmod +x ${SCCACHE_INSTALL_LOCATION}/sccache; \
fi
# Install dependencies
RUN dnf update -y && DEBIAN_FRONTEND=noninteractive dnf install -y \
cmake \
clang-tools-extra \
gcc-c++ \
libstdc++ \
libstdc++-devel \
libstdc++-static \
git \
hip-rocclr \
jq \
mpich \
net-tools \
pkg-config \
redis \
sshpass \
stunnel \
vim \
nano \
zip \
openssh-server \
kmod && \
dnf clean all && \
rm -rf /var/lib/apt/lists/* && \
rm -rf amdgpu-install* && \
#Install latest ccache
git clone https://github.com/ccache/ccache.git && \
cd ccache && mkdir build && cd build && cmake .. && make install && \
#Install ClangBuildAnalyzer
git clone https://github.com/aras-p/ClangBuildAnalyzer.git && \
cd ClangBuildAnalyzer/ && \
make -f projects/make/Makefile && \
cd / && \
#Install latest cppcheck
git clone https://github.com/danmar/cppcheck.git && \
cd cppcheck && mkdir build && cd build && cmake .. && cmake --build . && \
cd / && \
# Install packages for processing the performance results
pip3 install --break-system-packages --upgrade pytest pymysql pandas==2.2.3 sqlalchemy==2.0.3 setuptools-rust setuptools sshtunnel==0.4.0 && \
# Add render group
groupadd -f render && \
# Install the new rocm-cmake version
git clone -b master https://github.com/ROCm/rocm-cmake.git && \
cd rocm-cmake && mkdir build && cd build && \
cmake .. && cmake --build . && cmake --build . --target install
WORKDIR /
# Add alternative compilers, if necessary
ENV compiler_version=$compiler_version
ENV compiler_commit=$compiler_commit
RUN sh -c "echo compiler version = '$compiler_version'" && \
sh -c "echo compiler commit = '$compiler_commit'"
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" = "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 8 ; \
else echo "using the release compiler"; \
fi
RUN if ( [ "$compiler_version" = "amd-staging" ] || [ "$compiler_version" = "amd-mainline" ] ) && [ "$compiler_commit" != "" ]; then \
git clone -b "$compiler_version" https://github.com/ROCm/llvm-project.git && \
cd llvm-project && git checkout "$compiler_commit" && echo "checking out commit $compiler_commit" && mkdir build && cd build && \
cmake -DCMAKE_INSTALL_PREFIX=/opt/rocm/llvm -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=1 -DLLVM_TARGETS_TO_BUILD="AMDGPU;X86" -DLLVM_ENABLE_PROJECTS="clang;lld" -DLLVM_ENABLE_RUNTIMES="compiler-rt" ../llvm && \
make -j 8 ; \
else echo "using the release compiler"; \
fi

View File

@@ -20,4 +20,13 @@ RUN groupadd -g 109 render && \
git clone -b "$CK_PYTORCH_BRANCH" https://github.com/ROCm/composable_kernel.git && \
chown -R jenkins:jenkins /tmp/pytorch && \
chmod -R a+rwx /tmp/pytorch && \
sudo usermod -aG irc jenkins
sudo usermod -aG irc jenkins && \
#install hipblaslt
git clone --no-checkout --filter=blob:none https://github.com/ROCm/rocm-libraries.git && \
cd rocm-libraries && \
git checkout develop && \
git sparse-checkout init --cone && \
git sparse-checkout set projects/hipblaslt shared/origami && \
cd projects/hipblaslt && \
git show --oneline -s && \
CPLUS_INCLUDE_PATH="/opt/amdgpu/include/" ./install.sh -idc --architecture="gfx942;gfx950" -j 128 --skip_rocroller

645
Jenkinsfile vendored

File diff suppressed because it is too large Load Diff

View File

@@ -137,6 +137,22 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa
```
**[See Note on -j](#notes)**
### Building for Windows
Install TheRock and run CMake configure as
```bash
cmake \
-D CMAKE_PREFIX_PATH="C:/dist/TheRock" \
-D CMAKE_CXX_COMPILER="C:/dist/TheRock/bin/hipcc.exe" \
-D CMAKE_BUILD_TYPE=Release \
-D GPU_TARGETS="gfx1151" \
-G Ninja \
..
```
Use Ninja to build either the whole library or individual targets.
## Optional post-install steps
* Build examples and tests:
@@ -187,7 +203,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb
Additional cmake flags can be used to significantly speed-up the build:
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
other data types.

View File

@@ -1,2 +1,5 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_gemm gemm.cpp)
target_link_libraries(client_gemm PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_custom_target(client_gemm_fastgelu_examples)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_gemm_add_add_layernorm_naive gemm_add_add_layernorm_naive.cpp)
target_link_libraries(client_gemm_add_add_layernorm_naive PRIVATE composable_kernel::device_gemm_operations composable_kernel::device_other_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_contraction_scale_fp32 contraction_scale_fp32.cpp)
target_link_libraries(client_contraction_scale_fp32 PRIVATE composable_kernel::device_other_operations composable_kernel::device_contraction_operations composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_layernorm2d_bwd_data layernorm2d_bwd_data.cpp)
target_link_libraries(client_layernorm2d_bwd_data PRIVATE composable_kernel::device_other_operations)

View File

@@ -1,2 +1,5 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_softmax4d softmax4d.cpp)
target_link_libraries(client_softmax4d PRIVATE composable_kernel::device_other_operations composable_kernel::device_reduction_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_conv2d_fwd grouped_conv2d_fwd.cpp)
target_link_libraries(client_grouped_conv2d_fwd PRIVATE composable_kernel::device_conv_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_fused_attention fused_attention.cpp)
target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_other_operations composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9" AND (DTYPES MATCHES "int8" OR NOT DEFINED DTYPES))
add_executable(client_conv2d_fwd_bias_tanh_perchannel_quantization conv2d_fwd_bias_tanh_perchannel_quantization.cpp)
target_link_libraries(client_conv2d_fwd_bias_tanh_perchannel_quantization PRIVATE composable_kernel::device_conv_operations composable_kernel::device_other_operations composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_grouped_conv2d_bwd_data grouped_conv2d_bwd_data.cpp)
target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::device_conv_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_grouped_conv1d_bwd_weight_fp16 grouped_conv1d_bwd_weight_fp16.cpp)
add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp)
add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp)

View File

@@ -1,2 +1,5 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_elementwise_layernorm2d elementwise_layernorm2d.cpp)
target_link_libraries(client_elementwise_layernorm2d PRIVATE composable_kernel::device_other_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp)
add_executable(client_batchnorm_bwd_nhwc batchnorm_bwd_nhwc.cpp)
add_executable(client_batchnorm_infer_nhwc batchnorm_infer_nhwc.cpp)

View File

@@ -1,2 +1,5 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_batchnorm_fwd_instance_id batchnorm_fwd_instance_id.cpp)
target_link_libraries(client_batchnorm_fwd_instance_id PRIVATE composable_kernel::device_other_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_conv3d_bwd_data_fp16 conv3d_bwd_data_fp16.cpp)
add_executable(client_conv3d_bwd_data_fp32 conv3d_bwd_data_fp32.cpp)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
add_executable(client_conv3d_fwd_fp16 conv3d_fwd_fp16.cpp)
target_link_libraries(client_conv3d_fwd_fp16 PRIVATE composable_kernel::device_conv_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_gemm_fastgelu grouped_gemm_fastgelu.cpp)
target_link_libraries(client_grouped_gemm_fastgelu PRIVATE composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_groupnorm_bwd_data groupnorm_bwd_data.cpp)
target_link_libraries(client_groupnorm_bwd_data PRIVATE composable_kernel::device_other_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_max_pool2d_fwd max_pool2d_fwd.cpp)
target_link_libraries(client_max_pool2d_fwd PRIVATE composable_kernel::device_other_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_gemm_fixed_nk_bias_fp16 grouped_gemm_fixed_nk_bias_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_bias_fp16 PRIVATE composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_grouped_gemm_fixed_nk_fp16 grouped_gemm_fixed_nk_fp16.cpp)
target_link_libraries(client_grouped_gemm_fixed_nk_fp16 PRIVATE composable_kernel::device_gemm_operations)

View File

@@ -1,2 +1,5 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_elementwise_transpose3d elementwise_transpose_3d.cpp)
target_link_libraries(client_elementwise_transpose3d PRIVATE composable_kernel::device_other_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
# Fwd scaleadd scaleadd relu
add_executable(client_grouped_convnd_fwd_scaleadd_scaleadd_relu_fp32

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_tensor_transform_using_wrapper tensor_transform_using_wrapper.cpp)
target_link_libraries(client_tensor_transform_using_wrapper PRIVATE composable_kernel::device_other_operations)
add_executable(client_wrapper_img2col wrapper_img2col.cpp)

View File

@@ -1,2 +1,5 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_reduce_nhwc_c reduce_nhwc_c.cpp)
target_link_libraries(client_reduce_nhwc_c PRIVATE composable_kernel::device_reduction_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(client_image_to_column image_to_column.cpp)
target_link_libraries(client_image_to_column PRIVATE composable_kernel::device_other_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx950")
add_executable(client_gemm_mx_fp8 gemm_mx_fp8.cpp)
target_link_libraries(client_gemm_mx_fp8 PRIVATE composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9")
add_executable(client_gemm_add_multiply gemm_add_multiply.cpp)
target_link_libraries(client_gemm_add_multiply PRIVATE composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
add_executable(client_gemm_bias_fastgelu_bf16_i8_bf16 gemm_bias_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "int8" AND DTYPES MATCHES "bf16") OR NOT DEFINED DTYPES))
add_executable(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 grouped_gemm_bias_fastgelu_xdl_bf16_i8.cpp)
target_link_libraries(client_grouped_gemm_bias_fastgelu_bf16_i8_bf16 PRIVATE composable_kernel::device_gemm_operations)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
cmake_minimum_required(VERSION 3.15)
project(ck_app)
add_compile_options(-std=c++20)
@@ -24,6 +27,9 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "tf32")
set(CK_ENABLE_TF32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
@@ -38,6 +44,7 @@ else()
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_TF32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
if (GPU_TARGETS MATCHES "gfx94")
@@ -64,6 +71,14 @@ if (GPU_TARGETS)
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "ON")
else()
message(STATUS "Disabling TF32 instances for this target")
remove_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "OFF")
endif()
else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON")

View File

@@ -35,7 +35,7 @@ function(generate_sharded_instantiations)
set(GENERATED_SOURCE_FILES "")
set(EXTERN_TEMPLATE_STATEMENTS "")
set(CALL_STATEMENTS "")
message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}")
message(DEBUG "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}")
set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}")

View File

@@ -68,6 +68,9 @@ set(GTEST_CXX_FLAGS
-Wno-deprecated
-Wno-unsafe-buffer-usage
-Wno-float-equal
-Wno-lifetime-safety-intra-tu-suggestions
-Wno-lifetime-safety-cross-tu-suggestions
-Wno-character-conversion
)
if(WIN32)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
cmake_minimum_required(VERSION 3.16)
project(composable_kernel_host)
@@ -12,6 +15,7 @@ configure_file(${CK_ROOT}/include/ck/config.h.in ${CK_ROOT}/include/ck/config.h)
find_package(ROCM)
include(ROCMInstallTargets)
include(ROCMTest)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm $ENV{ROCM_PATH})
find_package(hiprtc REQUIRED)
rocm_setup_version(VERSION 1.0)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
add_subdirectory(rtc)
file(GLOB TEST_SRCS CONFIGURE_DEPENDS *.cpp)

View File

@@ -1,3 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
find_package(hip)
file(GLOB RTC_SOURCES CONFIGURE_DEPENDS src/*.cpp)
add_library(ck_rtc ${RTC_SOURCES})

117
dispatcher/CMakeLists.txt Normal file
View File

@@ -0,0 +1,117 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
cmake_minimum_required(VERSION 3.16)
project(ck_tile_dispatcher VERSION 1.0.0 LANGUAGES CXX)
# C++17 required
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)
# Find HIP for headers (needed for validation kernels)
find_package(hip QUIET)
if(NOT hip_FOUND)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip)
find_package(hip REQUIRED)
endif()
# Dispatcher library
add_library(ck_tile_dispatcher
src/registry.cpp
src/dispatcher.cpp
)
# Enable PIC for Python bindings
set_target_properties(ck_tile_dispatcher PROPERTIES
POSITION_INDEPENDENT_CODE ON
)
target_include_directories(ck_tile_dispatcher
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
)
# Link against CK Tile headers (header-only)
target_include_directories(ck_tile_dispatcher
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/../include>
$<INSTALL_INTERFACE:include>
)
# Link against HIP headers if available
if(hip_FOUND)
target_link_libraries(ck_tile_dispatcher PUBLIC hip::host)
endif()
# Compiler warnings
if(CMAKE_CXX_COMPILER_ID MATCHES "GNU|Clang")
target_compile_options(ck_tile_dispatcher PRIVATE
-Wall -Wextra -Wpedantic
)
elseif(CMAKE_CXX_COMPILER_ID MATCHES "MSVC")
target_compile_options(ck_tile_dispatcher PRIVATE
/W4
)
endif()
# Optional: Build tests
option(BUILD_DISPATCHER_TESTS "Build dispatcher unit tests" OFF)
if(BUILD_DISPATCHER_TESTS)
enable_testing()
add_subdirectory(tests)
endif()
# Optional: Build Python bindings
option(BUILD_DISPATCHER_PYTHON "Build Python bindings for dispatcher" OFF)
if(BUILD_DISPATCHER_PYTHON)
add_subdirectory(python)
endif()
# Optional: Codegen for tile_engine integration
option(DISPATCHER_AUTO_GENERATE_WRAPPERS "Auto-generate wrappers from tile_engine" OFF)
if(DISPATCHER_AUTO_GENERATE_WRAPPERS)
add_subdirectory(codegen)
endif()
# Optional: Build examples
option(BUILD_DISPATCHER_EXAMPLES "Build dispatcher examples" OFF)
if(BUILD_DISPATCHER_EXAMPLES)
add_subdirectory(examples)
endif()
# Optional: Build ctypes bindings
option(BUILD_DISPATCHER_BINDINGS "Build language bindings for dispatcher" OFF)
if(BUILD_DISPATCHER_BINDINGS)
add_subdirectory(bindings/ctypes)
endif()
# If codegen is enabled, add generated include directory
if(DISPATCHER_AUTO_GENERATE_WRAPPERS AND DISPATCHER_GENERATED_INCLUDE_DIR)
target_include_directories(ck_tile_dispatcher
PUBLIC
$<BUILD_INTERFACE:${DISPATCHER_GENERATED_INCLUDE_DIR}>
)
endif()
# Installation
install(TARGETS ck_tile_dispatcher
EXPORT ck_tile_dispatcher_targets
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib
RUNTIME DESTINATION bin
)
install(DIRECTORY include/
DESTINATION include
FILES_MATCHING PATTERN "*.hpp"
)
install(EXPORT ck_tile_dispatcher_targets
FILE ck_tile_dispatcher_targets.cmake
NAMESPACE ck_tile::
DESTINATION lib/cmake/ck_tile_dispatcher
)

736
dispatcher/README.md Normal file
View File

@@ -0,0 +1,736 @@
# CK Tile Dispatcher
A unified kernel dispatch system for AMD GPUs with C++ and Python frontends.
**Validated Platform:** AMD Instinct MI300 series (gfx942)
---
## Table of Contents
1. [Quick Start](#quick-start)
2. [Docker Setup](#docker-setup-recommended)
3. [Prerequisites](#prerequisites)
4. [Step-by-Step Build Guide](#step-by-step-build-guide)
5. [Running Examples](#running-examples)
6. [External Integration](#external-integration)
7. [Core Concepts](#core-concepts)
8. [Troubleshooting](#troubleshooting)
9. [File Structure](#file-structure)
---
## Quick Start
**Complete setup from scratch (5 minutes):**
```bash
# From the composable_kernel root directory
cd dispatcher
# Step 1: Create build directory
mkdir -p build && cd build
# Step 2: Configure CMake
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx942" \
-DBUILD_DISPATCHER_EXAMPLES=ON
# Step 3: Generate kernels and build (CMake handles this automatically)
make -j$(nproc)
# Step 4: Run C++ examples
./examples/gemm_01_basic
# Step 5: Build Python libraries (required for Python examples)
make python_libs
# Step 6: Run Python examples (from dispatcher directory)
cd ..
python3 examples/gemm/python/01_basic_gemm.py
```
---
## Docker Setup (Recommended)
For a reproducible build environment, use the official ROCm Docker image:
### Step 1: Pull and Run Container
```bash
# Pull the CK Docker image
docker pull rocm/composable_kernel:ck_ub24.04_rocm7.0.1
# Run container with GPU access
docker run \
-it \
--privileged \
--device=/dev/kfd \
--device=/dev/dri \
--group-add video \
--group-add render \
-w /root/workspace \
-v $(pwd):/root/workspace \
rocm/composable_kernel:ck_ub24.04_rocm7.0.1 \
/bin/bash
```
> **Note:** Omit `--device` flags if building without GPU access.
### Step 2: Clone and Build
```bash
# Inside the container
git clone https://github.com/ROCm/composable_kernel.git
cd composable_kernel
git checkout builder-dispatch-tile-gemm
# Set up Python environment
python3 -m venv .venv
source .venv/bin/activate
pip install numpy
# Build dispatcher
cd dispatcher
mkdir -p build && cd build
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx942" \
-DBUILD_DISPATCHER_EXAMPLES=ON
make -j$(nproc)
```
### One-Liner Build (inside container)
```bash
git clone https://github.com/ROCm/composable_kernel.git && \
cd composable_kernel && git checkout builder-dispatch-tile-gemm && \
python3 -m venv .venv && source .venv/bin/activate && pip install numpy && \
cd dispatcher && mkdir -p build && cd build && \
cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_BUILD_TYPE=Release -DGPU_TARGETS="gfx942" -DBUILD_DISPATCHER_EXAMPLES=ON && \
make -j$(nproc)
```
---
## Prerequisites
### Required Software
| Software | Minimum Version | Check Command |
|----------|-----------------|---------------|
| ROCm | 6.4+ | `rocminfo` |
| CMake | 3.16+ | `cmake --version` |
| Python | 3.8+ | `python3 --version` |
| NumPy | 1.20+ | `pip show numpy` |
| hipcc | (from ROCm) | `/opt/rocm/bin/hipcc --version` |
> **Note:** Newer GPU targets (gfx950, gfx1201) require ROCm 6.3 or newer, which is already satisfied by the ROCm 6.4+ minimum listed above. With ROCm 6.4+, you can also use `amdclang++` instead of `hipcc`.
### Check Your GPU Architecture
```bash
# Find your GPU architecture
rocminfo | grep -i "gfx"
# Example output: "gfx942"
```
**Supported architectures:**
- **gfx942** - MI300X, MI300A, MI308, MI325 (Instinct MI300 series)
- **gfx90a** - MI200 series (MI250, MI250X)
- **gfx950** - MI350 series
- **gfx1101** - RDNA3 series
- **gfx1201** - RDNA4 series
### Install Python Dependencies
NumPy is required for Python examples and kernel generation. We recommend using a virtual environment:
**Option 1: Using standard venv**
```bash
# Create virtual environment
python3 -m venv .venv
# Activate virtual environment
source .venv/bin/activate # Linux/macOS
# .venv\Scripts\activate # Windows
# Install NumPy
pip install numpy
```
**Option 2: Using uv (faster alternative)**
```bash
# Install uv if not already installed
curl -LsSf https://astral.sh/uv/install.sh | sh
# Create and activate virtual environment
uv venv .venv
source .venv/bin/activate # Linux/macOS
# .venv\Scripts\activate # Windows
# Install NumPy
uv pip install numpy
```
**Option 3: System-wide install (not recommended)**
```bash
pip install numpy
```
> **Note:** Always activate your virtual environment before running CMake or Python examples.
### Supported Data Types
CK Tile supports a wide range of data types for GEMM operations:
| A dtype | B dtype | Acc dtype | Warp Tile Sizes | Notes |
|---------|---------|-----------|-----------------|-------|
| `fp32` | `fp32` | `fp32` | 16x16x4, 16x16x16 | Full precision |
| `fp16` | `fp16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Standard half |
| `bf16` | `bf16` | `fp32` | 32x32x8, 32x32x16, 16x16x16, 16x16x32 | Brain float 16 |
| `fp8` | `fp8` | `fp32` | 32x32x16, 32x32x32, 16x16x32, 16x16x64 | FP8 E4M3 |
| `fp8` | `bf8` | `fp32` | 32x32x16, 16x16x32 | Mixed FP8/BF8 |
| `bf8` | `fp8` | `fp32` | 32x32x16, 16x16x128 | Mixed BF8/FP8 |
| `bf8` | `bf8` | `fp32` | 32x32x16, 32x32x32, 16x16x32 | BF8 E5M2 |
| `int8` | `int8` | `int32` | 32x32x16, 16x16x32, 16x16x16 | Integer GEMM |
| `pk_fp4` | `pk_fp4` | `fp32` | 16x16x128 | Packed 4-bit float |
**Notes:**
- Accumulator is always `fp32` except for `int8` which uses `int32`
- FP8 types: `fp8` = E4M3, `bf8` = E5M2
- `pk_fp4` = Packed 4-bit float (2 values per byte)
- Some dtypes require specific GPU architectures (e.g., FP8 requires MI300+)
---
## Step-by-Step Build Guide
### Step 1: Navigate to Dispatcher Directory
```bash
# From composable_kernel root
cd dispatcher
# Verify you're in the right place
ls CMakeLists.txt # Should exist
```
### Step 2: Create Build Directory
```bash
mkdir -p build
cd build
```
### Step 3: Configure CMake
**Basic configuration (library only):**
```bash
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx942"
```
**Full configuration (with examples and tests):**
```bash
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx942" \
-DBUILD_DISPATCHER_EXAMPLES=ON \
-DBUILD_DISPATCHER_TESTS=ON
```
**Expected output:**
```
-- Found hip: /opt/rocm (found suitable version "6.x.x")
-- Generating GEMM kernels...
-- Built: gemm_01 through gemm_06, dispatcher_gemm_lib.so
-- Configuring done
```
### Step 4: Build
```bash
# Build all targets (generates kernels automatically, then compiles)
make -j$(nproc)
# Or build specific targets
make gemm_01_basic # Single GEMM example
make dispatcher_gemm_lib # GEMM shared library for Python
# Build ONLY Python libraries (faster if you don't need C++ examples)
make python_libs -j$(nproc)
```
### Kernel Generation Targets
Kernels are generated automatically during `make`, but you can also control generation explicitly:
```bash
# Generate all kernels only (no compilation)
make generate_all_kernels
# Generate GEMM kernels only
make generate_gemm_kernels
# Force regenerate (even if kernels exist)
make regenerate_all_kernels
make regenerate_gemm_kernels
# Generate for specific GPU architecture
make generate_kernels_gfx942 # MI300X
make generate_kernels_gfx90a # MI200
make generate_kernels_gfx1100 # RDNA3
```
### Step 5: Verify Build
```bash
# Check executables were built
ls examples/gemm_*
# Check shared libraries were built
ls examples/libdispatcher_gemm_lib.so
```
### CMake Options Reference
| Flag | Default | Description |
|------|---------|-------------|
| `CMAKE_BUILD_TYPE` | Debug | **Use `Release` for performance!** |
| `GPU_TARGETS` | None | Target GPU: `"gfx942"`, `"gfx90a"`, etc. |
| `BUILD_DISPATCHER_EXAMPLES` | OFF | Build C++ examples and Python libs |
| `BUILD_DISPATCHER_TESTS` | OFF | Build unit tests |
| `CMAKE_PREFIX_PATH` | - | ROCm installation path |
| `CMAKE_CXX_COMPILER` | - | Path to hipcc compiler |
⚠️ **Important:** Always use `-DCMAKE_BUILD_TYPE=Release` for benchmarking. Debug builds are slower.
⚠️ **Important:** The current system supports only a single GPU target at a time for architecture-based kernel filtering. Do not specify multiple GPU targets in one configuration; if you need several architectures, compile each one into a separate build directory.
---
## Running Examples
### C++ Examples
After building, executables are in `build/examples/`:
```bash
cd build/examples
# GEMM Examples
./gemm_01_basic # Basic GEMM with autofill/autocorrect
./gemm_02_multi_size # Wildcard expansion
./gemm_03_benchmark_validation # Benchmarking + validation
./gemm_04_heuristics # Heuristic kernel selection
./gemm_05_json_export # Registry JSON export
./gemm_06_multi_registry # Multiple registries
```
### Python Examples
Run from the `dispatcher` directory:
```bash
cd /path/to/composable_kernel/dispatcher
# GEMM Examples
python3 examples/gemm/python/01_basic_gemm.py # Basic multi-kernel GEMM
python3 examples/gemm/python/04_validation.py # CPU reference validation
python3 examples/gemm/python/07_stress_test.py # Stress test (48 kernels)
python3 examples/gemm/python/08_heuristics.py # Heuristic selection
```
### Example Output
**Expected C++ output (`gemm_01_basic`):**
```
======================================================================
Example 01: Basic GEMM with Declarative Kernel Definition
======================================================================
Step 1: Declared Kernels
------------------------
Kernel Set: fp16_gemm_kernels
Architecture: gfx942
Configurations: 1
- gemm_fp16_rcr_compv4_cshuffle_intrawave_128x128x32
Step 2: Create Registry and Dispatcher
--------------------------------------
Registered 1 kernels
Step 3: Define Problem
----------------------
M=1024, N=1024, K=1024
Step 4: GPU Execution
---------------------
*** GPU EXECUTION ***
Time: <varies> ms
TFLOPS: <varies>
```
> **Note:** Timing values vary by GPU model and system configuration.
---
## Benchmark Parameters
The dispatcher supports fine-grained control over benchmarking, matching CK Tile's `stream_config`:
### Available Parameters
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `warmup` | int | 5 | Warmup iterations (discarded from timing) |
| `repeat` | int | 20 | Benchmark iterations (averaged) |
| `flush_cache` | bool | false | Flush GPU L2 cache between iterations |
| `rotating_count` | int | 1 | Rotating buffer count (for cache simulation) |
| `timer` | string | "gpu" | Timer type: "gpu" (HIP events) or "cpu" |
| `init` | string | "random" | Matrix initialization: "random", "linear", "constant" |
| `split_k` | int | 1 | Split-K parallelism factor |
### Python Usage
```python
from ctypes_utils import DispatcherLib
# Basic usage (default benchmark settings)
lib = DispatcherLib.load()
# Advanced benchmark settings via command line
python3 examples/gemm/python/10_advanced_benchmark.py \
--warmup 10 \
--repeat 100 \
--flush-cache
```
### C++ Usage
```cpp
// Basic timing
ck_tile::stream_config cfg{nullptr, true};
// Advanced benchmark settings
ck_tile::stream_config cfg{
nullptr, // stream_id (nullptr = default stream)
true, // time_kernel
1, // log_level
10, // cold_niters (warmup)
100, // nrepeat
true, // is_gpu_timer
true, // flush_cache
4 // rotating_count
};
float avg_time = kernel.run(args, cfg);
```
### Command Line (Python Examples)
```bash
# Basic run
python3 examples/gemm/python/10_advanced_benchmark.py
# With benchmark parameters
python3 examples/gemm/python/10_advanced_benchmark.py \
--warmup 10 \
--repeat 100 \
--flush-cache \
--rotating-count 4 \
--timer gpu
```
### When to Use Each Parameter
| Use Case | Recommended Settings |
|----------|---------------------|
| Quick test | `warmup=1, repeat=3` |
| Stable benchmark | `warmup=10, repeat=100` |
| Memory-bound analysis | `flush_cache=True, rotating_count=4` |
| Compute-bound analysis | `flush_cache=False` (default) |
| Debug timing | `timer="cpu"` |
| Production | `timer="gpu"` (default) |
---
## External Integration
### Using Dispatcher in Your Own Project
#### Option 1: CMake Integration (Recommended)
Add to your `CMakeLists.txt`:
```cmake
# Set path to composable_kernel
set(CK_ROOT "/path/to/composable_kernel")
# Add dispatcher subdirectory
add_subdirectory(${CK_ROOT}/dispatcher dispatcher_build)
# Link to your target
target_link_libraries(your_target PRIVATE ck_tile_dispatcher)
target_include_directories(your_target PRIVATE
${CK_ROOT}/dispatcher/include
${CK_ROOT}/include
)
```
#### Option 2: Include as Pre-built Library
```cmake
# Find the pre-built library
find_library(CK_DISPATCHER ck_tile_dispatcher
PATHS /path/to/composable_kernel/dispatcher/build)
# Include directories
set(CK_INCLUDE_DIRS
/path/to/composable_kernel/include
/path/to/composable_kernel/dispatcher/include
)
target_link_libraries(your_target PRIVATE ${CK_DISPATCHER})
target_include_directories(your_target PRIVATE ${CK_INCLUDE_DIRS})
```
#### Option 3: Python Integration
```python
import sys
sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python")
# For GEMM
from ctypes_utils import DispatcherLib, Dispatcher, KernelConfig
```
### Required Include Paths
When integrating, you need these include paths:
```
/path/to/composable_kernel/include # CK Tile core headers
/path/to/composable_kernel/dispatcher/include # Dispatcher headers
/path/to/composable_kernel/dispatcher/build/generated_kernels # Generated kernels
```
### Required Compile Flags
```bash
# Minimum flags for hipcc
-std=c++17
-D__HIP_PLATFORM_AMD__=1
--offload-arch=gfx942 # Your target GPU
# Recommended flags
-O3
-mllvm -enable-noalias-to-md-conversion=0
-Wno-undefined-func-template
-Wno-float-equal
-Wall
-Werror
```
### Python Path Setup
For Python scripts outside the dispatcher directory:
```bash
# Option 1: Environment variable
export PYTHONPATH="/path/to/composable_kernel/dispatcher/examples/gemm/python:$PYTHONPATH"
# Option 2: In your Python script
import sys
sys.path.insert(0, "/path/to/composable_kernel/dispatcher/examples/gemm/python")
```
### Library Search Paths
The Python utilities search for the shared library in these locations:
```python
# For GEMM (ctypes_utils.py)
SEARCH_PATHS = [
"build/examples/libdispatcher_gemm_lib.so",
"../build/examples/libdispatcher_gemm_lib.so",
"../../build/examples/libdispatcher_gemm_lib.so",
]
```
If using from a different location, set the library path explicitly:
```python
# GEMM
from ctypes_utils import DispatcherLib
lib = DispatcherLib.load("/absolute/path/to/libdispatcher_gemm_lib.so")
```
---
## Core Concepts
### Data Flow
```
KernelConfig → Registry → Dispatcher → GPU Execution
```
1. **KernelConfig**: Defines kernel parameters (tile sizes, data types, layouts)
2. **Registry**: Stores multiple kernel configurations
3. **Dispatcher**: Selects best kernel for a given problem and executes it
### GEMM Layouts
| Layout | A | B | C | Use Case |
|--------|---|---|---|----------|
| RCR | Row | Col | Row | Most common (PyTorch default) |
| RRR | Row | Row | Row | Both inputs row-major |
| CRR | Col | Row | Row | A transposed |
| CCR | Col | Col | Row | Both inputs column-major |
### Split-K Support
Split-K divides the K dimension across multiple thread blocks, useful for large K dimensions.
**Usage (C++):**
```cpp
// GEMM with 4-way K split
auto problem = ProblemBuilder()
.m(1024).n(1024).k(8192)
.split_k(4)
.build();
```
---
## Troubleshooting
### Build Issues
| Problem | Solution |
|---------|----------|
| `hipcc not found` | Set `-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc` |
| `hip not found` | Set `-DCMAKE_PREFIX_PATH=/opt/rocm` |
| Very slow performance | Use `-DCMAKE_BUILD_TYPE=Release` |
| `gfx942 not supported` | Check ROCm version (need 6.4+, see [Prerequisites](#prerequisites)) |
| Kernel generation fails | Ensure Python 3.8+ with NumPy installed in active venv |
| Build errors | First verify CK builds without dispatcher (see main CK README) |
### Runtime Issues
| Problem | Solution |
|---------|----------|
| `Library not found` | Build with `-DBUILD_DISPATCHER_EXAMPLES=ON` |
| `No kernel found` | Check GPU arch matches build target |
| Python `ModuleNotFoundError` | Add paths to `PYTHONPATH` (see above) |
| Wrong results | Verify layout matches your data |
### Debug Commands
```bash
# Check ROCm installation
rocminfo | head -20
# Check GPU architecture
rocminfo | grep "Name:"
# Verify library exists
ls -la build/examples/libdispatcher_*.so
# Run with verbose output
./build/examples/gemm_01_basic 2>&1
# Python: Check library loading
python3 -c "
import ctypes
lib = ctypes.CDLL('/path/to/libdispatcher_gemm_lib.so')
print('Library loaded successfully')
"
```
### Clean Rebuild
If you encounter issues, try a clean rebuild:
```bash
cd dispatcher
rm -rf build
mkdir build && cd build
cmake .. [your options]
make -j$(nproc)
```
---
## File Structure
```
dispatcher/
├── README.md # This file
├── CMakeLists.txt # Build configuration
├── include/ck_tile/dispatcher/ # C++ headers
│ ├── dispatcher.hpp # GEMM dispatcher
│ ├── registry.hpp # Kernel registry
│ └── kernel_key.hpp # Kernel configuration
├── src/ # C++ implementation
├── codegen/ # Kernel generation
│ ├── unified_gemm_codegen.py # GEMM kernel generator
│ └── arch_specs.json # GPU specifications
├── bindings/ctypes/ # Python ctypes interface
│ └── gemm_ctypes_lib.cpp # GEMM Python library
├── examples/ # Examples
│ └── gemm/
│ ├── cpp/ # C++ GEMM examples (01-06)
│ └── python/ # Python GEMM examples (01-11)
├── scripts/ # Build scripts
└── tests/ # Unit tests
```
---
## Example Documentation
| Directory | README |
|-----------|--------|
| GEMM C++ | [examples/gemm/cpp/README.md](examples/gemm/cpp/README.md) |
| GEMM Python | [examples/gemm/python/README.md](examples/gemm/python/README.md) |
| Codegen | [codegen/README.md](codegen/README.md) |
---
## Archived Content
Convolution examples and utilities have been archived to `ck-2/conv_archive/dispatcher/`:
- `examples/conv/cpp/` - 11 C++ convolution examples
- `examples/conv/python/` - 14 Python convolution examples
- `codegen/unified_conv_codegen.py` - Conv kernel generator
- `include/ck_tile/dispatcher/conv_*.hpp` - Conv headers
- `python/conv_utils.py` - Conv Python utilities
---
## License
MIT License - Copyright (c) 2025, Advanced Micro Devices, Inc.

View File

@@ -0,0 +1,109 @@
# CK Tile Dispatcher - Language Bindings
This directory contains language bindings for the CK Tile Dispatcher.
## Structure
```
bindings/
├── ctypes/ # Python ctypes bindings (C API)
│ ├── gemm_ctypes_lib.cpp # GEMM dispatcher C API
│ ├── conv_ctypes_lib.cpp # Convolution dispatcher C API (fwd + bwd_data)
│ ├── conv_bwdw_ctypes_lib.cpp # Convolution backward weight C API
│ ├── gpu_helper.cpp # CLI helper for Python
│ └── CMakeLists.txt
└── README.md
```
## ctypes Bindings
The ctypes bindings provide a C API that Python can load via `ctypes.CDLL()`.
### Building
```bash
cd build
cmake .. -DCMAKE_PREFIX_PATH=/opt/rocm
make dispatcher_gemm_lib dispatcher_conv_lib gpu_helper
```
### Usage from Python
```python
import ctypes
# Load the library
lib = ctypes.CDLL("path/to/libdispatcher_gemm_lib.so")
# Initialize
lib.dispatcher_init()
# Check if problem is supported
is_supported = lib.dispatcher_is_supported(M, N, K)
# Run GEMM
time_ms = ctypes.c_float()
result = lib.dispatcher_run_gemm(
A_ptr, B_ptr, C_ptr,
M, N, K,
ctypes.byref(time_ms)
)
# Cleanup
lib.dispatcher_cleanup()
```
### GEMM API
| Function | Description |
|----------|-------------|
| `dispatcher_init()` | Initialize the dispatcher |
| `dispatcher_is_supported(M, N, K)` | Check if problem size is supported |
| `dispatcher_select_kernel(M, N, K, name_buf, buf_size)` | Get kernel name for problem |
| `dispatcher_run_gemm(A, B, C, M, N, K, time_ms)` | Execute GEMM |
| `dispatcher_get_kernel_count()` | Get number of registered kernels |
| `dispatcher_export_registry_json()` | Export registry as JSON |
| `dispatcher_cleanup()` | Release resources |
### Convolution API
| Function | Description |
|----------|-------------|
| `conv_dispatcher_init()` | Initialize the dispatcher |
| `conv_dispatcher_is_supported(prob)` | Check if problem is supported |
| `conv_dispatcher_select_kernel(prob, name_buf, buf_size)` | Get kernel name |
| `conv_dispatcher_run(input, weight, output, prob, stream)` | Execute convolution |
| `conv_dispatcher_get_kernel_count()` | Get number of registered kernels |
| `conv_dispatcher_cleanup()` | Release resources |
## GPU Helper
The `gpu_helper` executable provides a CLI interface for Python:
```bash
./gpu_helper 1024 1024 1024 --validate
```
Output is JSON for easy parsing:
```json
{
"problem": {"M": 1024, "N": 1024, "K": 1024},
"kernel": "gemm_fp16_rcr_...",
"execution": {
"time_ms": 0.5,
"tflops": 4.2
},
"validation": {
"accuracy": 100.0
},
"status": "success"
}
```
## Examples
See the examples that use these bindings:
- **GEMM**: `dispatcher/examples/gemm/python/`
- **Conv**: `dispatcher/examples/conv/python/`

View File

@@ -0,0 +1,181 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# =============================================================================
# CK Tile Dispatcher - ctypes Bindings
# =============================================================================
#
# Provides shared libraries with C API for Python ctypes integration.
#
# Targets:
# - dispatcher_gemm_lib : GEMM dispatcher library
# - dispatcher_conv_lib : Convolution dispatcher library (forward + bwd_data)
# - dispatcher_conv_bwdw_lib : Convolution backward weight library
# - gpu_helper : GPU helper executable for Python
#
cmake_minimum_required(VERSION 3.16)
# Helper function to add a ctypes library
#
#   add_ctypes_library(<target> <source> [CONV] [KERNEL_HEADER <header>])
#
#   TARGET_NAME   - name of the SHARED library target to create
#   SOURCE_FILE   - the single .cpp implementation file
#   KERNEL_HEADER - optional generated kernel header; when given AND present on
#                   disk it is force-included (-include) into the translation unit
#   CONV          - when set together with a valid KERNEL_HEADER, additionally
#                   defines CONV_KERNEL_AVAILABLE for the target
#
# Every target gets the dispatcher/CK include paths, links hip::device, and is
# built as position-independent C++17 code (required for a Python-loadable .so).
function(add_ctypes_library TARGET_NAME SOURCE_FILE)
cmake_parse_arguments(ARG "CONV" "KERNEL_HEADER" "" ${ARGN})
add_library(${TARGET_NAME} SHARED ${SOURCE_FILE})
target_include_directories(${TARGET_NAME} PRIVATE
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/dispatcher/include
)
target_link_libraries(${TARGET_NAME} PRIVATE
hip::device
)
# Force-include kernel header if provided (and only if it actually exists)
if(ARG_KERNEL_HEADER AND EXISTS ${ARG_KERNEL_HEADER})
target_compile_options(${TARGET_NAME} PRIVATE
-include ${ARG_KERNEL_HEADER}
)
if(ARG_CONV)
target_compile_definitions(${TARGET_NAME} PRIVATE CONV_KERNEL_AVAILABLE)
endif()
endif()
set_target_properties(${TARGET_NAME} PROPERTIES
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD 17
)
endfunction()
# =============================================================================
# GEMM ctypes Library
# =============================================================================
# Find a generated GEMM kernel header for the library.
# GEMM_KERNEL_HEADER is also reused below by the gpu_helper executable.
file(GLOB GEMM_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/gemm_*.hpp")
if(GEMM_KERNEL_HEADERS)
list(GET GEMM_KERNEL_HEADERS 0 GEMM_KERNEL_HEADER)
message(STATUS "Found GEMM kernel for ctypes lib: ${GEMM_KERNEL_HEADER}")
add_ctypes_library(dispatcher_gemm_lib
gemm_ctypes_lib.cpp
KERNEL_HEADER ${GEMM_KERNEL_HEADER}
)
else()
message(STATUS "No GEMM kernel found for ctypes lib - building without kernel")
# Use the same helper so the fallback target gets identical include paths,
# link libraries, PIC and C++17 settings. The previous hand-rolled fallback
# omitted POSITION_INDEPENDENT_CODE and CXX_STANDARD; the helper safely
# skips the force-include when no KERNEL_HEADER argument is given.
add_ctypes_library(dispatcher_gemm_lib gemm_ctypes_lib.cpp)
endif()
# =============================================================================
# Convolution ctypes Library (supports forward + bwd_data)
# =============================================================================
# Look for forward kernels
file(GLOB CONV_FWD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_fwd_*.hpp")
# Look for backward data kernels
file(GLOB CONV_BWDD_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_bwdd_*.hpp")
# Fallback: any conv kernel (for backwards compatibility with older naming).
# Backward-weight kernels are excluded from the fallback: they are handled by
# the separate dispatcher_conv_bwdw_lib target, and force-including one here
# would define CONV_KERNEL_AVAILABLE with a kernel this library cannot launch.
file(GLOB CONV_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*.hpp")
list(FILTER CONV_KERNEL_HEADERS EXCLUDE REGEX "bwd_weight")
add_library(dispatcher_conv_lib SHARED conv_ctypes_lib.cpp)
target_include_directories(dispatcher_conv_lib PRIVATE
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/dispatcher/include
)
target_link_libraries(dispatcher_conv_lib PRIVATE hip::device)
set_target_properties(dispatcher_conv_lib PROPERTIES
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD 17
)
# Add forward kernel if available
if(CONV_FWD_KERNEL_HEADERS)
list(GET CONV_FWD_KERNEL_HEADERS 0 CONV_FWD_KERNEL_HEADER)
message(STATUS "Found Conv FWD kernel for ctypes lib: ${CONV_FWD_KERNEL_HEADER}")
target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_FWD_KERNEL_HEADER})
target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE)
elseif(CONV_KERNEL_HEADERS)
# Fallback to any (non-bwd_weight) conv kernel
list(GET CONV_KERNEL_HEADERS 0 CONV_KERNEL_HEADER)
message(STATUS "Found Conv kernel for ctypes lib: ${CONV_KERNEL_HEADER}")
target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_KERNEL_HEADER})
target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_KERNEL_AVAILABLE)
else()
message(STATUS "No Conv FWD kernel found for ctypes lib - building without kernel")
endif()
# Add backward data kernel if available (can coexist with the forward kernel)
if(CONV_BWDD_KERNEL_HEADERS)
list(GET CONV_BWDD_KERNEL_HEADERS 0 CONV_BWDD_KERNEL_HEADER)
message(STATUS "Found Conv BWD_DATA kernel for ctypes lib: ${CONV_BWDD_KERNEL_HEADER}")
target_compile_options(dispatcher_conv_lib PRIVATE -include ${CONV_BWDD_KERNEL_HEADER})
target_compile_definitions(dispatcher_conv_lib PRIVATE CONV_BWD_DATA_AVAILABLE)
endif()
# =============================================================================
# Convolution Backward Weight ctypes Library (separate lib for bwd_weight)
# =============================================================================
file(GLOB CONV_BWDW_KERNEL_HEADERS "${CMAKE_BINARY_DIR}/generated_kernels/conv_*bwd_weight*.hpp")
# Create the target once; the kernel header (if any) is attached below. This
# removes the duplicated add_library/include/link/properties boilerplate that
# previously appeared in both branches.
add_library(dispatcher_conv_bwdw_lib SHARED conv_bwdw_ctypes_lib.cpp)
target_include_directories(dispatcher_conv_bwdw_lib PRIVATE
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/dispatcher/include
)
target_link_libraries(dispatcher_conv_bwdw_lib PRIVATE hip::device)
set_target_properties(dispatcher_conv_bwdw_lib PROPERTIES
POSITION_INDEPENDENT_CODE ON
CXX_STANDARD 17
)
if(CONV_BWDW_KERNEL_HEADERS)
list(GET CONV_BWDW_KERNEL_HEADERS 0 CONV_BWDW_KERNEL_HEADER)
message(STATUS "Found Conv BwdWeight kernel for ctypes lib: ${CONV_BWDW_KERNEL_HEADER}")
target_compile_options(dispatcher_conv_bwdw_lib PRIVATE
-include ${CONV_BWDW_KERNEL_HEADER}
)
target_compile_definitions(dispatcher_conv_bwdw_lib PRIVATE CONV_BWD_WEIGHT_AVAILABLE)
else()
message(STATUS "No Conv BwdWeight kernel found for ctypes lib - building without kernel")
endif()
# =============================================================================
# GPU Helper Executable
# =============================================================================
# The helper CLI needs a GEMM kernel to run, so it is skipped entirely when no
# gemm_*.hpp header was found above (GEMM_KERNEL_HEADER is set in that branch).
if(GEMM_KERNEL_HEADERS)
add_executable(gpu_helper gpu_helper.cpp)
target_include_directories(gpu_helper PRIVATE
${PROJECT_SOURCE_DIR}/include
${PROJECT_SOURCE_DIR}/dispatcher/include
)
target_link_libraries(gpu_helper PRIVATE
hip::device
)
# Force-include the same GEMM kernel header selected for dispatcher_gemm_lib.
target_compile_options(gpu_helper PRIVATE
-include ${GEMM_KERNEL_HEADER}
)
set_target_properties(gpu_helper PROPERTIES
CXX_STANDARD 17
)
endif()

View File

@@ -0,0 +1,175 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Convolution Backward Weight Dispatcher ctypes Library
*
* SEPARATE library for backward weight to avoid template conflicts with
* forward/backward_data kernels in the main conv_ctypes_lib.
*
* Usage from Python:
* lib = ctypes.CDLL("libdispatcher_conv_bwdw_lib.so")
* lib.conv_bwdw_init()
* lib.conv_bwdw_run(...)
*/
#include <cstring>
#include <vector>
#include <hip/hip_runtime.h>
// Minimal includes - matching the C++ example
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/ops/gemm.hpp" // Must be before grouped_convolution for TileGemmTraits
#include "ck_tile/ops/grouped_convolution.hpp"
// Global state - minimal, no registry needed for direct launch
static bool g_bwdw_initialized = false;
extern "C" {
// =============================================================================
// Initialization (minimal - just sets flag)
// =============================================================================
// Mark the library as initialized. No registry or GPU resources are created:
// this library launches its single force-included kernel directly, so the
// flag only gates conv_bwdw_run().
int conv_bwdw_init()
{
    g_bwdw_initialized = true;
    return 0; // Return 0 on success (consistent with other init functions)
}
// Reset the flag; subsequent conv_bwdw_run() calls return -1 until re-init.
void conv_bwdw_cleanup() { g_bwdw_initialized = false; }
// =============================================================================
// Problem Structure (same as main library)
// =============================================================================
// C-compatible convolution problem descriptor, passed from Python via ctypes.
// NOTE(review): 2D callers presumably pass input_d == 1 and filter_z == 1 —
// build_conv_param() keys its 2D/3D decision on those two fields; confirm
// against the Python bindings.
struct ConvBwdwProblemC
{
    int N, G, C, K;                         // batch, groups, input channels, output channels
    int input_d, input_h, input_w;          // input spatial sizes (depth unused for 2D)
    int filter_z, filter_y, filter_x;       // filter spatial sizes (Z unused for 2D)
    int stride_d, stride_h, stride_w;       // convolution strides
    int pad_d, pad_h, pad_w;                // padding (applied symmetrically left/right)
    int dilation_d, dilation_h, dilation_w; // filter dilations
};
// =============================================================================
// Backward Weight Execution
// =============================================================================
#ifdef CONV_BWD_WEIGHT_AVAILABLE
// Translate the C problem descriptor into a ck_tile ConvParam.
//
// Dimensionality heuristic: the problem is treated as 3D when either the input
// depth or the filter depth exceeds 1; otherwise the depth fields are ignored
// and a 2D parameter set is built. The same padding values are used for both
// the left and right pad arguments (symmetric padding).
static ck_tile::conv::ConvParam build_conv_param(const ConvBwdwProblemC* prob)
{
    if(prob->input_d <= 1 && prob->filter_z <= 1)
    {
        // 2D: only the H/W components of each spatial triple are used.
        return ck_tile::conv::ConvParam{2,
                                        prob->G,
                                        prob->N,
                                        prob->K,
                                        prob->C,
                                        {prob->filter_y, prob->filter_x},
                                        {prob->input_h, prob->input_w},
                                        {prob->stride_h, prob->stride_w},
                                        {prob->dilation_h, prob->dilation_w},
                                        {prob->pad_h, prob->pad_w},
                                        {prob->pad_h, prob->pad_w}};
    }
    // 3D: full D/H/W triples.
    return ck_tile::conv::ConvParam{3,
                                    prob->G,
                                    prob->N,
                                    prob->K,
                                    prob->C,
                                    {prob->filter_z, prob->filter_y, prob->filter_x},
                                    {prob->input_d, prob->input_h, prob->input_w},
                                    {prob->stride_d, prob->stride_h, prob->stride_w},
                                    {prob->dilation_d, prob->dilation_h, prob->dilation_w},
                                    {prob->pad_d, prob->pad_h, prob->pad_w},
                                    {prob->pad_d, prob->pad_h, prob->pad_w}};
}
// Launch the force-included backward-weight kernel and return the measured
// kernel time in milliseconds (as reported by the launcher).
//
// Operand mapping for backward weight: A = input, B = grad_output,
// C = grad_weight (the only buffer written).
static float run_bwd_weight_impl(const void* input_ptr,
                                 const void* grad_output_ptr,
                                 void* grad_weight_ptr,
                                 const ConvBwdwProblemC* prob,
                                 void* stream)
{
    auto conv_param = build_conv_param(prob);
    // Backward weight: A=input, B=grad_output, C=grad_weight
    ck_tile::GroupedConvBwdWeightHostArgs args(conv_param,
                                               input_ptr,       // in_ptr = input
                                               grad_weight_ptr, // wei_ptr = grad_weight (output)
                                               {},              // ds_ptr
                                               grad_output_ptr, // out_ptr = grad_output
                                               1                // k_batch
    );
    // stream_config fields: {stream_id, time_kernel, log_level, cold_niters, nrepeat}
    // => time the kernel on the caller's stream with 3 warmup and 10 timed runs.
    ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
    return SelectedConvBwdWeightLauncher::launch(args, stream_cfg);
}
#endif
// Execute the backward-weight convolution.
//
// Returns the kernel time in milliseconds, or -1.0f when the library was
// built without a backward-weight kernel, init was not called, or any
// argument pointer is null.
float conv_bwdw_run(const void* input_ptr,
                    const void* grad_output_ptr,
                    void* grad_weight_ptr,
                    const ConvBwdwProblemC* prob,
                    void* stream)
{
#ifdef CONV_BWD_WEIGHT_AVAILABLE
    // Refuse to launch unless initialized and every pointer (descriptor plus
    // all three tensors) is non-null; a null data pointer would crash the
    // kernel on the device.
    const bool ready = g_bwdw_initialized && prob != nullptr && input_ptr != nullptr &&
                       grad_output_ptr != nullptr && grad_weight_ptr != nullptr;
    if(!ready)
        return -1.0f;
    return run_bwd_weight_impl(input_ptr, grad_output_ptr, grad_weight_ptr, prob, stream);
#else
    return -1.0f;
#endif
}
// =============================================================================
// Info
// =============================================================================
// Library version string (static storage; caller must not free it).
const char* conv_bwdw_version() { return "1.0.0"; }
// 1 when a backward-weight kernel was force-included at build time, else 0.
int conv_bwdw_has_kernels()
{
#ifdef CONV_BWD_WEIGHT_AVAILABLE
    return 1;
#else
    return 0;
#endif
}
// Number of available kernels: at most one kernel is compiled in, so this is
// exactly 1 when built with a kernel and 0 otherwise.
int conv_bwdw_get_kernel_count()
{
#ifdef CONV_BWD_WEIGHT_AVAILABLE
    return 1;
#else
    return 0;
#endif
}
// Copy the (single) built-in kernel's name into `buffer`, truncating to fit;
// the result is always NUL-terminated.
// Returns 0 on success, -1 on bad index/arguments or when no kernel is built in.
int conv_bwdw_get_kernel_name(int index, char* buffer, int buffer_size)
{
#ifdef CONV_BWD_WEIGHT_AVAILABLE
    if(!buffer || buffer_size <= 0)
        return -1;
    if(index != 0)
        return -1; // exactly one kernel is compiled in
    std::strncpy(buffer, CONV_BWD_WEIGHT_KERNEL_NAME, buffer_size - 1);
    buffer[buffer_size - 1] = '\0';
    return 0;
#else
    return -1;
#endif
}
} // extern "C"

View File

@@ -0,0 +1,411 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Convolution Dispatcher ctypes Library
*
* Provides C API for Python ctypes integration.
* Supports forward convolution. Backward operations require additional headers.
*
* REQUIRED: Forward kernel header must be force-included via -include flag.
* OPTIONAL: Backward kernels can be added with CONV_BWD_DATA_AVAILABLE/CONV_BWD_WEIGHT_AVAILABLE
*
* Usage from Python:
* lib = ctypes.CDLL("libdispatcher_conv.so")
* lib.conv_dispatcher_init()
* lib.conv_dispatcher_run(...)
*/
#include <cstring>
#include <memory>
#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/dispatcher/conv_utils.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
using namespace ck_tile::dispatcher;
// Global state (using shared_ptr for safe memory management)
static std::shared_ptr<ConvRegistry> g_registry = nullptr;
static std::shared_ptr<ConvDispatcher> g_dispatcher = nullptr;
static std::vector<const ConvKernelInstance*> g_kernels;
extern "C" {
// =============================================================================
// Initialization
// =============================================================================
// Initialize the global registry/dispatcher and register the kernel
// configurations that were force-included at compile time.
// Idempotent: calling again after a successful init is a no-op.
// Returns 0 (no failure paths at present).
int conv_dispatcher_init()
{
    if(g_registry)
        return 0; // Already initialized
    g_registry = std::make_shared<ConvRegistry>();
    // ConvDispatcher receives a raw pointer into g_registry, so g_registry
    // must outlive g_dispatcher (see the reset order in conv_dispatcher_cleanup).
    g_dispatcher = std::make_shared<ConvDispatcher>(g_registry.get());
    // Register kernel configurations using simple ConvKernelSet
    // (actual kernel launch uses the force-included SelectedConvKernelLauncher)
    using namespace ck_tile::dispatcher::conv_decl;
    // Forward kernels (required - must be force-included)
    // Must match: conv_fwd_fp16_nhwgc_2d_compv4_cshuffle_intrawave_128x128x64_2x2x1_32x32x16_dsb
    ConvKernelSet fwd_set;
    fwd_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("forward").dims(2),
                ConvAlgorithm()
                    .tile(128, 128, 64) // tile_m x tile_n x tile_k
                    .wave(2, 2, 1)
                    .warp(32, 32, 16)
                    .pipeline("compv4")
                    .scheduler("intrawave"),
                "gfx942");
    g_registry->register_set(fwd_set, ConvRegistry::Priority::High);
#ifdef CONV_BWD_DATA_AVAILABLE
    // Backward data kernels (optional, compile-time gated)
    // Must match: conv_bwdd_fp16_nhwgc_2d_compv3_cshuffle_intrawave_128x128x64_2x2x1_32x32x16
    ConvKernelSet bwd_data_set;
    bwd_data_set.add(ConvSignature().dtype("fp16").layout("nhwgc").conv_type("bwd_data").dims(2),
                     ConvAlgorithm()
                         .tile(128, 128, 64) // tile_m x tile_n x tile_k
                         .wave(2, 2, 1)
                         .warp(32, 32, 16)
                         .pipeline("compv3")
                         .scheduler("intrawave"),
                     "gfx942");
    g_registry->register_set(bwd_data_set, ConvRegistry::Priority::High);
#endif
    return 0;
}
// Release all global dispatcher state. Safe to call repeatedly.
int conv_dispatcher_cleanup()
{
    g_kernels.clear();
    // Drop the dispatcher before the registry it holds a raw pointer into.
    g_dispatcher.reset();
    g_registry.reset();
    return 0;
}
// =============================================================================
// Registry Management
// =============================================================================
// Number of registered kernels; 0 before conv_dispatcher_init().
int conv_dispatcher_get_kernel_count()
{
    return g_registry ? static_cast<int>(g_registry->size()) : 0;
}
// Copy the name of the index-th registered kernel into `buffer`, truncated
// to fit and always NUL-terminated.
// Returns 0 on success, -1 on any invalid argument or uninitialized state.
int conv_dispatcher_get_kernel_name(int index, char* buffer, int buffer_size)
{
    if(!g_registry)
        return -1;
    if(index < 0 || !buffer || buffer_size <= 0)
        return -1;
    // Kernels are registered under their full names; read them back here.
    const auto& kernels = g_registry->all_kernels();
    if(static_cast<size_t>(index) >= kernels.size())
        return -1;
    const auto* entry = kernels[index];
    std::strncpy(buffer, entry->name().c_str(), buffer_size - 1);
    buffer[buffer_size - 1] = '\0';
    return 0;
}
// =============================================================================
// Problem Definition
// =============================================================================
// Plain-C mirror of a convolution problem, passed from Python via ctypes.
// Field order and types define the ABI - do not reorder or resize.
// For 2D problems the depth components are still present but ignored;
// build_conv_param() decides 2D vs 3D from input_d / filter_z.
struct ConvProblemC
{
    int N, G, C, K;                         // batch, groups, input channels, output channels
    int input_d, input_h, input_w;          // input spatial extents (D unused for 2D)
    int filter_z, filter_y, filter_x;       // filter spatial extents
    int stride_d, stride_h, stride_w;       // convolution strides
    int pad_d, pad_h, pad_w;                // padding (same value used for both sides)
    int dilation_d, dilation_h, dilation_w; // filter dilations
    int direction; // 0=forward, 1=bwd_data, 2=bwd_weight
};
// =============================================================================
// Kernel Selection
// =============================================================================
// Returns 1 when the dispatcher can select a kernel for `prob`, else 0.
int conv_dispatcher_is_supported(const ConvProblemC* prob)
{
    // Fix: also check g_dispatcher - it is the object actually dereferenced
    // below. The original only checked g_registry, which would null-deref if
    // the two globals ever got out of sync (consistent with conv_dispatcher_run).
    if(!g_registry || !g_dispatcher || !prob)
        return 0;
    // Translate the flat C struct into the dispatcher's problem description.
    ConvProblem problem;
    problem.N = prob->N;
    problem.G = prob->G;
    problem.C = prob->C;
    problem.K = prob->K;
    problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
    problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
    problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
    problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
    problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
    problem.op = static_cast<ConvOp>(prob->direction);
    problem.compute_output_size();
    const auto* kernel = g_dispatcher->select(problem);
    return kernel ? 1 : 0;
}
// Write the name of the kernel selected for `prob` into `kernel_name`
// (truncated to fit, always NUL-terminated).
// Returns 0 on success, -1 on invalid arguments or when no kernel matches.
int conv_dispatcher_select_kernel(const ConvProblemC* prob, char* kernel_name, int buffer_size)
{
    // Fix: also check g_dispatcher, the object actually dereferenced below
    // (the original only checked g_registry).
    if(!g_registry || !g_dispatcher || !prob || !kernel_name || buffer_size <= 0)
        return -1;
    ConvProblem problem;
    problem.N = prob->N;
    problem.G = prob->G;
    problem.C = prob->C;
    problem.K = prob->K;
    problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
    problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
    problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
    problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
    problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
    problem.op = static_cast<ConvOp>(prob->direction);
    problem.compute_output_size();
    const auto* kernel = g_dispatcher->select(problem);
    if(!kernel)
        return -1;
    std::strncpy(kernel_name, kernel->name().c_str(), buffer_size - 1);
    kernel_name[buffer_size - 1] = '\0';
    return 0;
}
// =============================================================================
// Convolution Execution
// =============================================================================
// Helper: translate the flat C problem struct into a ck_tile ConvParam.
// The same pad values are passed for both pad arguments (presumably
// left/right padding - confirm against the ConvParam constructor).
static ck_tile::conv::ConvParam build_conv_param(const ConvProblemC* prob)
{
    // Determine if this is 2D or 3D convolution.
    // NOTE(review): a genuine 3D problem with input_d == 1 and filter_z == 1
    // is classified as 2D by this heuristic - confirm callers never rely on
    // a depth-1 3D layout.
    const bool is_3d = (prob->input_d > 1 || prob->filter_z > 1);
    if(is_3d)
    {
        // 3D convolution: use all spatial dimensions
        return ck_tile::conv::ConvParam{3,
                                        prob->G,
                                        prob->N,
                                        prob->K,
                                        prob->C,
                                        {prob->filter_z, prob->filter_y, prob->filter_x},
                                        {prob->input_d, prob->input_h, prob->input_w},
                                        {prob->stride_d, prob->stride_h, prob->stride_w},
                                        {prob->dilation_d, prob->dilation_h, prob->dilation_w},
                                        {prob->pad_d, prob->pad_h, prob->pad_w},
                                        {prob->pad_d, prob->pad_h, prob->pad_w}};
    }
    else
    {
        // 2D convolution: only use H, W dimensions
        return ck_tile::conv::ConvParam{2,
                                        prob->G,
                                        prob->N,
                                        prob->K,
                                        prob->C,
                                        {prob->filter_y, prob->filter_x},
                                        {prob->input_h, prob->input_w},
                                        {prob->stride_h, prob->stride_w},
                                        {prob->dilation_h, prob->dilation_w},
                                        {prob->pad_h, prob->pad_w},
                                        {prob->pad_h, prob->pad_w}};
    }
}
// Forward convolution (required - kernel header must be force-included)
static float run_forward(const void* input_ptr,
const void* weight_ptr,
void* output_ptr,
const ConvProblemC* prob,
void* stream)
{
auto conv_param = build_conv_param(prob);
ck_tile::GroupedConvFwdHostArgs<> args(conv_param, input_ptr, weight_ptr, {}, output_ptr, 1);
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
// SelectedConvKernelLauncher is defined in the force-included forward kernel header
return SelectedConvKernelLauncher::launch(args, stream_cfg);
}
#ifdef CONV_BWD_DATA_AVAILABLE
// Backward data convolution (optional)
// Computes: grad_input = conv_bwd_data(weight, grad_output)
//
// Parameters:
// grad_output_ptr: dY - gradient from next layer (const, read-only INPUT)
// weight_ptr: W - frozen weights (const, read-only INPUT)
// grad_input_ptr: dX - gradient for input (writable, OUTPUT)
static float run_bwd_data(const void* grad_output_ptr,
const void* weight_ptr,
void* grad_input_ptr,
const ConvProblemC* prob,
void* stream)
{
auto conv_param = build_conv_param(prob);
// CK Tile API uses tensor POSITION names (from forward pass), not data flow:
// in_ptr = input tensor position = grad_input_ptr (dX, OUTPUT of bwd_data)
// wei_ptr = weight tensor = weight_ptr (W, const)
// out_ptr = output tensor position = grad_output_ptr (dY, INPUT to bwd_data)
ck_tile::GroupedConvBwdDataHostArgs args(
conv_param, grad_input_ptr, weight_ptr, {}, grad_output_ptr, 1);
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
return SelectedConvBwdDataLauncher::launch(args, stream_cfg);
}
#endif
#ifdef CONV_BWD_WEIGHT_AVAILABLE
// Backward weight convolution (optional)
// Parameters:
// input_ptr: original forward input X (const, read-only)
// grad_output_ptr: gradient from next layer dY (const, read-only)
// grad_weight_ptr: gradient of weights dW (writable, OUTPUT)
static float run_bwd_weight(const void* input_ptr,
const void* grad_output_ptr,
void* grad_weight_ptr,
const ConvProblemC* prob,
void* stream)
{
auto conv_param = build_conv_param(prob);
// GroupedConvBwdWeightHostArgs constructor order:
// (param, in=X, wei=dW (output), ds, out=dY (input), k_batch)
// Note: wei_ptr is the OUTPUT (grad_weight), out_ptr is the INPUT (grad_output)
ck_tile::GroupedConvBwdWeightHostArgs args(
conv_param, input_ptr, grad_weight_ptr, {}, grad_output_ptr, 1);
ck_tile::stream_config stream_cfg{static_cast<hipStream_t>(stream), true, 1, 3, 10};
return SelectedConvBwdWeightLauncher::launch(args, stream_cfg);
}
#endif
/**
 * @brief Execute convolution based on direction specified in prob
 *
 * Parameter mapping varies by direction:
 *   Forward (direction=0):
 *     input_ptr  = X (input tensor)
 *     weight_ptr = W (weight tensor)
 *     output_ptr = Y (output buffer)
 *
 *   Backward Data (direction=1):
 *     input_ptr  = dY (grad_output - gradient from next layer)
 *     weight_ptr = W (weight tensor, frozen)
 *     output_ptr = dX (grad_input buffer)
 *
 *   Backward Weight (direction=2):
 *     input_ptr  = X (forward input tensor)
 *     weight_ptr = dY (grad_output - gradient from next layer)
 *     output_ptr = dW (grad_weight buffer)
 *
 * @return kernel execution time in ms, or -1.0f on failure (uninitialized
 *         state, null pointer, no matching kernel, or a direction whose
 *         kernel was not compiled in).
 */
float conv_dispatcher_run(const void* input_ptr,
                          const void* weight_ptr,
                          void* output_ptr,
                          const ConvProblemC* prob,
                          void* stream)
{
    // Validate all required pointers before kernel launch
    if(!g_dispatcher || !prob)
        return -1.0f;
    if(!input_ptr || !weight_ptr || !output_ptr)
        return -1.0f; // Null data pointer would cause kernel crash
    // Build problem for kernel selection
    ConvProblem problem;
    problem.N = prob->N;
    problem.G = prob->G;
    problem.C = prob->C;
    problem.K = prob->K;
    problem.input_spatial = {prob->input_d, prob->input_h, prob->input_w};
    problem.filter_spatial = {prob->filter_z, prob->filter_y, prob->filter_x};
    problem.stride = {prob->stride_d, prob->stride_h, prob->stride_w};
    problem.padding = {prob->pad_d, prob->pad_h, prob->pad_w};
    problem.dilation = {prob->dilation_d, prob->dilation_h, prob->dilation_w};
    problem.op = static_cast<ConvOp>(prob->direction);
    problem.compute_output_size();
    // Select kernel. The result only gates execution: the actual launch below
    // always goes through the force-included Selected*Launcher types.
    const auto* kernel = g_dispatcher->select(problem);
    if(!kernel)
        return -1.0f;
    // Dispatch based on direction
    switch(prob->direction)
    {
    case 0: // Forward (always available)
        return run_forward(input_ptr, weight_ptr, output_ptr, prob, stream);
#ifdef CONV_BWD_DATA_AVAILABLE
    case 1: // Backward data
        // Convention: caller passes (grad_output, weight, grad_input_buffer)
        // in the (input_ptr, weight_ptr, output_ptr) slots respectively.
        // run_bwd_data expects: (grad_output, weight, grad_input)
        return run_bwd_data(input_ptr, weight_ptr, output_ptr, prob, stream);
#endif
#ifdef CONV_BWD_WEIGHT_AVAILABLE
    case 2: // Backward weight
        // Convention: caller passes (input, grad_output, grad_weight_buffer)
        // in the (input_ptr, weight_ptr, output_ptr) slots respectively.
        // run_bwd_weight expects: (input, grad_output, grad_weight)
        return run_bwd_weight(input_ptr, weight_ptr, output_ptr, prob, stream);
#endif
    default: return -1.0f;
    }
}
// =============================================================================
// Info
// =============================================================================
const char* conv_dispatcher_version() { return "1.0.0"; }
// Always 1: the forward kernel is mandatory (force-included at build time).
int conv_dispatcher_has_kernels()
{
    return 1; // Forward kernel is required
}
// Returns 1 when a backward-data kernel was compiled in, 0 otherwise.
int conv_dispatcher_has_bwd_data()
{
#ifdef CONV_BWD_DATA_AVAILABLE
    return 1;
#else
    return 0;
#endif
}
// Returns 1 when a backward-weight kernel was compiled in, 0 otherwise.
int conv_dispatcher_has_bwd_weight()
{
#ifdef CONV_BWD_WEIGHT_AVAILABLE
    return 1;
#else
    return 0;
#endif
}
} // extern "C"

View File

@@ -0,0 +1,401 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* GEMM Dispatcher ctypes Library
*
* Provides C API for Python ctypes integration.
* Kernel header included via -include at compile time.
*
* Usage from Python:
* lib = ctypes.CDLL("libdispatcher_gemm.so")
* lib.dispatcher_init()
* lib.dispatcher_run_gemm(...)
*/
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp"
// Kernel header included via -include compiler flag
// Defines: ADataType, BDataType, CDataType, AccDataType, SelectedKernel, KERNEL_NAME
// GPU architecture - can be overridden via -DGFX_ARCH="gfx90a" at compile time
#ifndef GFX_ARCH
#define GFX_ARCH "gfx942"
#endif
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::backends;
using Priority = ck_tile::dispatcher::Registry::Priority;
// Global dispatcher (initialized once, managed via shared_ptr for safe cleanup)
static std::shared_ptr<Dispatcher> g_dispatcher = nullptr;
static bool g_initialized = false;
// Check a HIP call and bail out with -1 on failure.
// Only valid inside functions returning int.
// Fix: wrapped in do { } while(0) so the macro behaves as a single statement
// (the bare { } form breaks inside if/else without braces and leaves a stray
// empty statement after the trailing semicolon).
#define HIP_CHECK(call)              \
    do                               \
    {                                \
        hipError_t err = (call);     \
        if(err != hipSuccess)        \
        {                            \
            return -1;               \
        }                            \
    } while(0)
extern "C" {
/**
 * Initialize dispatcher with a kernel
 * Must be called before run_gemm
 *
 * The KernelKey built here must describe the force-included kernel header
 * exactly; it is what the dispatcher matches incoming problems against.
 *
 * Returns: 0 on success, -1 on error
 */
int dispatcher_initialize()
{
    if(g_initialized)
    {
        return 0; // Already initialized
    }
    // Create kernel key from the force-included kernel header
    KernelKey key;
    key.signature.dtype_a = DataType::FP16;
    key.signature.dtype_b = DataType::FP16;
    key.signature.dtype_c = DataType::FP16;
    key.signature.dtype_acc = DataType::FP32;
    key.signature.layout_a = LayoutTag::RowMajor;
    key.signature.layout_b = LayoutTag::ColMajor;
    key.signature.layout_c = LayoutTag::RowMajor;
    key.signature.transpose_a = false;
    key.signature.transpose_b = false;
    key.signature.grouped = false;
    key.signature.split_k = 1;
    key.signature.elementwise_op = "PassThrough";
    key.signature.num_d_tensors = 0;
    key.signature.structured_sparsity = false;
    key.algorithm.tile_shape = {128, 128, 32};
    key.algorithm.wave_shape = {2, 2, 1};
    key.algorithm.warp_tile_shape = {32, 32, 16};
    key.algorithm.pipeline = Pipeline::CompV4;
    key.algorithm.scheduler = Scheduler::Intrawave;
    key.algorithm.epilogue = Epilogue::CShuffle;
    key.algorithm.block_size = 256;
    key.algorithm.double_buffer = true;
    key.algorithm.persistent = false;
    key.algorithm.preshuffle = false;
    key.algorithm.transpose_c = false;
    key.algorithm.num_wave_groups = 1;
    key.gfx_arch = GFX_ARCH; // overridable via -DGFX_ARCH at compile time
    // Register kernel using types from force-included header
    auto kernel =
        create_generated_tile_kernel<SelectedKernel, ADataType, BDataType, CDataType, AccDataType>(
            key, KERNEL_NAME);
    // clear() first so a re-init after dispatcher_cleanup() (which does not
    // touch the Registry singleton) cannot register the kernel twice.
    Registry::instance().clear();
    Registry::instance().register_kernel(kernel, Priority::High);
    // Create dispatcher (using shared_ptr for safe memory management)
    g_dispatcher = std::make_shared<Dispatcher>();
    g_initialized = true;
    return 0;
}
/**
 * Get the tile configuration of the registered kernel.
 * Any output pointer may be null to skip that field.
 * Returns 0 on success, -1 when not initialized or no kernel is registered.
 */
int dispatcher_get_kernel_config(int* tile_m,
                                 int* tile_n,
                                 int* tile_k,
                                 int* warp_tile_m,
                                 int* warp_tile_n,
                                 int* warp_tile_k,
                                 int* warp_m,
                                 int* warp_n,
                                 int* warp_k)
{
    if(!g_initialized)
    {
        return -1;
    }
    auto kernels = Registry::instance().get_all();
    if(kernels.empty())
    {
        return -1;
    }
    // Report the configuration of the first registered kernel.
    const auto& algo = kernels[0]->get_key().algorithm;
    const auto put = [](int* dst, int value) {
        if(dst)
            *dst = value;
    };
    put(tile_m, algo.tile_shape.m);
    put(tile_n, algo.tile_shape.n);
    put(tile_k, algo.tile_shape.k);
    put(warp_tile_m, algo.warp_tile_shape.m);
    put(warp_tile_n, algo.warp_tile_shape.n);
    put(warp_tile_k, algo.warp_tile_shape.k);
    put(warp_m, algo.wave_shape.m);
    put(warp_n, algo.wave_shape.n);
    put(warp_k, algo.wave_shape.k);
    return 0;
}
/**
 * Write the name of the kernel selected for (M, N, K) into name_buffer,
 * truncated to fit and always NUL-terminated.
 * Returns 0 on success, -1 on bad arguments or when no kernel matches.
 */
int dispatcher_select_kernel(int64_t M, int64_t N, int64_t K, char* name_buffer, int buffer_size)
{
    if(!g_initialized || !name_buffer || buffer_size <= 0)
    {
        return -1;
    }
    auto kernel = g_dispatcher->select_kernel(Problem(M, N, K));
    if(!kernel)
    {
        return -1;
    }
    const std::string kernel_name = kernel->get_name();
    strncpy(name_buffer, kernel_name.c_str(), buffer_size - 1);
    name_buffer[buffer_size - 1] = '\0';
    return 0;
}
/**
 * Check whether any registered kernel can handle the (M, N, K) problem.
 * Returns 1 if supported, 0 otherwise (including uninitialized state or
 * non-positive dimensions).
 */
int dispatcher_is_supported(int64_t M, int64_t N, int64_t K)
{
    if(!g_initialized || M <= 0 || N <= 0 || K <= 0)
    {
        return 0;
    }
    Problem problem(M, N, K);
    return g_dispatcher->select_kernel(problem) ? 1 : 0;
}
/**
 * Run GEMM on GPU via dispatcher.
 *
 * A, B, C are HOST pointers (types from the force-included header; A is
 * row-major MxK, B col-major KxN, C row-major MxN). Device buffers are
 * allocated here, inputs copied in, the kernel run, and the result copied
 * back into C.
 *
 * Returns:
 *    0  success (*time_ms holds the kernel execution time in ms)
 *   -1  invalid arguments, allocation/copy failure, or kernel exception
 *   -2  no suitable kernel (*time_ms set to -1 when provided)
 */
int dispatcher_run_gemm(
    const void* A, const void* B, void* C, int64_t M, int64_t N, int64_t K, float* time_ms)
{
    if(!g_initialized || !A || !B || !C)
    {
        return -1;
    }
    // Fix: reject non-positive sizes up front; M * K * sizeof(...) below is
    // computed as size_t and would wrap around for negative values.
    if(M <= 0 || N <= 0 || K <= 0)
    {
        return -1;
    }
    // First check if any kernel supports this problem
    Problem problem(M, N, K);
    auto kernel = g_dispatcher->select_kernel(problem);
    if(!kernel)
    {
        if(time_ms)
        {
            *time_ms = -1.0f;
        }
        return -2; // No suitable kernel
    }
    // Cast to correct types (from force-included header)
    const ADataType* A_host = static_cast<const ADataType*>(A);
    const BDataType* B_host = static_cast<const BDataType*>(B);
    CDataType* C_host = static_cast<CDataType*>(C);
    // Allocate GPU memory; cleanup_gpu_mem releases whatever succeeded so far.
    ADataType* A_dev = nullptr;
    BDataType* B_dev = nullptr;
    CDataType* C_dev = nullptr;
    auto cleanup_gpu_mem = [&]() {
        if(A_dev)
            (void)hipFree(A_dev);
        if(B_dev)
            (void)hipFree(B_dev);
        if(C_dev)
            (void)hipFree(C_dev);
    };
    if(hipMalloc(&A_dev, M * K * sizeof(ADataType)) != hipSuccess)
    {
        cleanup_gpu_mem();
        return -1;
    }
    if(hipMalloc(&B_dev, K * N * sizeof(BDataType)) != hipSuccess)
    {
        cleanup_gpu_mem();
        return -1;
    }
    if(hipMalloc(&C_dev, M * N * sizeof(CDataType)) != hipSuccess)
    {
        cleanup_gpu_mem();
        return -1;
    }
    // Copy input data to GPU
    if(hipMemcpy(A_dev, A_host, M * K * sizeof(ADataType), hipMemcpyHostToDevice) != hipSuccess)
    {
        cleanup_gpu_mem();
        return -1;
    }
    if(hipMemcpy(B_dev, B_host, K * N * sizeof(BDataType), hipMemcpyHostToDevice) != hipSuccess)
    {
        cleanup_gpu_mem();
        return -1;
    }
    if(hipMemset(C_dev, 0, M * N * sizeof(CDataType)) != hipSuccess)
    {
        cleanup_gpu_mem();
        return -1;
    }
    // Run GEMM via dispatcher
    float exec_time = -1.0f;
    try
    {
        exec_time = g_dispatcher->run(A_dev, B_dev, C_dev, problem);
    }
    catch(...)
    {
        // Fix: catch everything, not only std::exception - an exception must
        // never propagate across the extern "C" boundary (undefined behavior
        // for ctypes callers). Also drops the previously-unused binding `e`.
        cleanup_gpu_mem();
        return -1;
    }
    // Copy result back to host
    if(hipMemcpy(C_host, C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost) != hipSuccess)
    {
        cleanup_gpu_mem();
        return -1;
    }
    if(time_ms)
    {
        *time_ms = exec_time;
    }
    cleanup_gpu_mem();
    return 0;
}
/**
 * Name of the force-included kernel (compile-time constant KERNEL_NAME).
 */
const char* dispatcher_get_kernel_name() { return KERNEL_NAME; }
/**
 * Short alias for dispatcher_initialize(), kept for ctypes callers.
 */
int dispatcher_init() { return dispatcher_initialize(); }
/**
 * Number of kernels in the global Registry singleton.
 */
int dispatcher_get_kernel_count() { return static_cast<int>(Registry::instance().size()); }
/**
 * Export registry to JSON string.
 *
 * Returns a pointer into a static buffer that remains valid until the next
 * call. NOTE(review): not thread-safe (shared g_json_buffer), and kernel
 * names/identifiers are not JSON-escaped - confirm generated names can never
 * contain quotes or backslashes.
 */
static std::string g_json_buffer; // backing storage for the returned pointer
const char* dispatcher_export_registry_json()
{
    auto& registry = Registry::instance();
    std::ostringstream json;
    json << "{\n";
    json << " \"metadata\": {\n";
    json << " \"timestamp\": \"" << __DATE__ << " " << __TIME__ << "\",\n";
    json << " \"total_kernels\": " << registry.size() << ",\n";
    json << " \"export_version\": \"1.0\",\n";
    json << " \"dispatcher_version\": \"1.0.0\"\n";
    json << " },\n";
    // Statistics sections are emitted empty placeholders.
    json << " \"statistics\": {\n";
    json << " \"by_datatype\": {},\n";
    json << " \"by_pipeline\": {},\n";
    json << " \"by_scheduler\": {}\n";
    json << " },\n";
    json << " \"kernels\": [\n";
    auto kernels = registry.get_all();
    for(size_t i = 0; i < kernels.size(); ++i)
    {
        auto& kernel = kernels[i];
        auto& key = kernel->get_key();
        auto& algo = key.algorithm;
        std::string name = kernel->get_name();
        json << " {\n";
        json << " \"identifier\": \"" << key.encode_identifier() << "\",\n";
        json << " \"name\": \"" << name << "\",\n";
        json << " \"algorithm\": {\n";
        json << " \"tile_shape\": {\"m\": " << algo.tile_shape.m
             << ", \"n\": " << algo.tile_shape.n << ", \"k\": " << algo.tile_shape.k << "},\n";
        // unsigned(...) casts make narrow integer fields print numerically
        // (presumably uint8_t, which would otherwise stream as characters).
        json << " \"wave_shape\": {\"m\": " << unsigned(algo.wave_shape.m)
             << ", \"n\": " << unsigned(algo.wave_shape.n)
             << ", \"k\": " << unsigned(algo.wave_shape.k) << "},\n";
        json << " \"warp_tile_shape\": {\"m\": " << unsigned(algo.warp_tile_shape.m)
             << ", \"n\": " << unsigned(algo.warp_tile_shape.n)
             << ", \"k\": " << unsigned(algo.warp_tile_shape.k) << "},\n";
        json << " \"block_size\": " << algo.block_size << ",\n";
        json << " \"persistent\": " << (algo.persistent ? "true" : "false") << ",\n";
        json << " \"double_buffer\": " << (algo.double_buffer ? "true" : "false") << ",\n";
        json << " \"preshuffle\": " << (algo.preshuffle ? "true" : "false") << ",\n";
        json << " \"transpose_c\": " << (algo.transpose_c ? "true" : "false") << "\n";
        json << " }\n";
        json << " }";
        if(i < kernels.size() - 1)
        {
            json << ","; // comma between array elements, none after the last
        }
        json << "\n";
    }
    json << " ]\n";
    json << "}\n";
    g_json_buffer = json.str();
    return g_json_buffer.c_str();
}
/**
 * Release dispatcher resources. The Registry singleton is left as-is;
 * dispatcher_initialize() clears and repopulates it on the next init.
 */
void dispatcher_cleanup()
{
    g_initialized = false;
    g_dispatcher.reset();
}
} // extern "C"

View File

@@ -0,0 +1,206 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* GPU Helper - C++ executable for GPU GEMM execution
*
* A CLI tool for Python to execute GPU GEMM with generated kernels.
* Usage: gpu_helper <M> <N> <K> [--validate]
*
* Kernel header included via -include flag at compile time.
*/
#include <iostream>
#include <vector>
#include <cstring>
#include <cmath>
#include <hip/hip_runtime.h>
#include "ck_tile/dispatcher/dispatcher.hpp"
#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/backends/generated_tile_backend.hpp"
// Kernel header included via -include compiler flag
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::backends;
using Priority = ck_tile::dispatcher::Registry::Priority;
// Check a HIP call; print the error string and exit(1) on failure.
// Fix: wrapped in do { } while(0) so the macro behaves as one statement
// (the bare { } form breaks inside if/else without braces and leaves a
// stray empty statement after the trailing semicolon).
#define HIP_CHECK(call)                                                   \
    do                                                                    \
    {                                                                     \
        hipError_t err = (call);                                          \
        if(err != hipSuccess)                                             \
        {                                                                 \
            std::cerr << "HIP_ERROR: " << hipGetErrorString(err) << "\n"; \
            exit(1);                                                      \
        }                                                                 \
    } while(0)
// CPU reference GEMM used to validate GPU results.
// A is MxK row-major, B is KxN column-major, C is MxN row-major;
// accumulation is done in float regardless of T.
template <typename T>
void cpu_gemm(
    const std::vector<T>& A, const std::vector<T>& B, std::vector<T>& C, int M, int N, int K)
{
    for(int row = 0; row < M; ++row)
    {
        const T* a_row = A.data() + static_cast<size_t>(row) * K;
        for(int col = 0; col < N; ++col)
        {
            // Column `col` of the col-major B is contiguous, length K.
            const T* b_col = B.data() + static_cast<size_t>(col) * K;
            float acc = 0.0f;
            for(int i = 0; i < K; ++i)
            {
                acc += float(a_row[i]) * float(b_col[i]);
            }
            C[static_cast<size_t>(row) * N + col] = T(acc);
        }
    }
}
// CLI entry point: runs one GEMM of the force-included kernel on the GPU and
// prints a JSON-like report to stdout for the Python caller to parse.
// Exit code 0 on success, 1 on usage error or failed kernel selection.
int main(int argc, char** argv)
{
    // Parse arguments
    if(argc < 4)
    {
        std::cerr << "Usage: " << argv[0] << " <M> <N> <K> [--validate]\n";
        std::cerr << "\nOptions:\n";
        std::cerr << " M, N, K : Problem dimensions\n";
        std::cerr << " --validate : Compare GPU results with CPU reference\n";
        return 1;
    }
    int M = std::atoi(argv[1]);
    int N = std::atoi(argv[2]);
    int K = std::atoi(argv[3]);
    bool validate = (argc > 4 && std::string(argv[4]) == "--validate");
    // Output in JSON-like format for easy Python parsing
    std::cout << "{" << std::endl;
    std::cout << " \"problem\": {\"M\": " << M << ", \"N\": " << N << ", \"K\": " << K << "},"
              << std::endl;
    std::cout << " \"kernel\": \"" << KERNEL_NAME << "\"," << std::endl;
    // Register kernel.
    // NOTE(review): this key must stay in sync with the force-included kernel
    // header; a mismatched tile/pipeline config would misdescribe the kernel.
    KernelKey key;
    key.signature.dtype_a = DataType::FP16;
    key.signature.dtype_b = DataType::FP16;
    key.signature.dtype_c = DataType::FP16;
    key.signature.dtype_acc = DataType::FP32;
    key.signature.layout_a = LayoutTag::RowMajor;
    key.signature.layout_b = LayoutTag::ColMajor;
    key.signature.layout_c = LayoutTag::RowMajor;
    key.signature.transpose_a = false;
    key.signature.transpose_b = false;
    key.signature.grouped = false;
    key.signature.split_k = 1;
    key.signature.elementwise_op = "PassThrough";
    key.signature.num_d_tensors = 0;
    key.signature.structured_sparsity = false;
    key.algorithm.tile_shape = {128, 128, 32};
    key.algorithm.wave_shape = {2, 2, 1};
    key.algorithm.warp_tile_shape = {32, 32, 16};
    key.algorithm.pipeline = Pipeline::CompV4;
    key.algorithm.scheduler = Scheduler::Intrawave;
    key.algorithm.epilogue = Epilogue::CShuffle;
    key.algorithm.block_size = 256;
    key.algorithm.double_buffer = true;
    key.algorithm.persistent = false;
    key.algorithm.preshuffle = false;
    key.algorithm.transpose_c = false;
    key.algorithm.num_wave_groups = 1;
    key.gfx_arch = "gfx942"; // hard-coded here, unlike the library's GFX_ARCH macro
    auto kernel =
        create_generated_tile_kernel<SelectedKernel, ADataType, BDataType, CDataType, AccDataType>(
            key, KERNEL_NAME);
    Registry::instance().clear();
    Registry::instance().register_kernel(kernel, Priority::High);
    Dispatcher dispatcher;
    Problem problem(M, N, K);
    auto selected = dispatcher.select_kernel(problem);
    if(!selected)
    {
        // Emit a well-formed JSON object even on failure.
        std::cout << " \"error\": \"No kernel selected\"" << std::endl;
        std::cout << "}" << std::endl;
        return 1;
    }
    std::cout << " \"selected_kernel\": \"" << selected->get_name() << "\"," << std::endl;
    // Prepare data: A=1, B=1, so C should be K
    std::vector<ADataType> A_host(M * K, ADataType(1.0f));
    std::vector<BDataType> B_host(K * N, BDataType(1.0f));
    std::vector<CDataType> C_gpu(M * N);
    // GPU execution: allocate, upload, zero the output, run, download.
    ADataType *A_dev, *B_dev;
    CDataType* C_dev;
    HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(ADataType)));
    HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(BDataType)));
    HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(CDataType)));
    HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(ADataType), hipMemcpyHostToDevice));
    HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(BDataType), hipMemcpyHostToDevice));
    HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(CDataType)));
    float gpu_time = dispatcher.run(A_dev, B_dev, C_dev, problem);
    HIP_CHECK(hipMemcpy(C_gpu.data(), C_dev, M * N * sizeof(CDataType), hipMemcpyDeviceToHost));
    // Calculate performance: 2*M*N*K flops for a GEMM, gpu_time is in ms.
    double flops = 2.0 * M * N * K;
    double tflops = (flops / (gpu_time * 1e-3)) / 1e12;
    std::cout << " \"execution\": {" << std::endl;
    std::cout << " \"time_ms\": " << gpu_time << "," << std::endl;
    std::cout << " \"tflops\": " << tflops << "," << std::endl;
    std::cout << " \"flops\": " << (long long)flops << std::endl;
    std::cout << " }," << std::endl;
    // Validation: element counts as correct when relative error < 2%
    // (1e-5 in the denominator guards against division by zero).
    if(validate)
    {
        std::vector<CDataType> C_cpu(M * N);
        cpu_gemm(A_host, B_host, C_cpu, M, N, K);
        int correct = 0;
        float max_error = 0.0f;
        for(int i = 0; i < M * N; i++)
        {
            float gpu_val = float(C_gpu[i]);
            float cpu_val = float(C_cpu[i]);
            float error = std::abs(gpu_val - cpu_val) / (std::abs(cpu_val) + 1e-5f);
            max_error = std::max(max_error, error);
            if(error < 0.02f)
            {
                correct++;
            }
        }
        float accuracy = 100.0f * correct / (M * N);
        std::cout << " \"validation\": {" << std::endl;
        std::cout << " \"accuracy\": " << accuracy << "," << std::endl;
        std::cout << " \"max_error\": " << max_error << "," << std::endl;
        std::cout << " \"correct_elements\": " << correct << "," << std::endl;
        std::cout << " \"total_elements\": " << M * N << std::endl;
        std::cout << " }," << std::endl;
    }
    std::cout << " \"status\": \"success\"" << std::endl;
    std::cout << "}" << std::endl;
    // Cleanup
    HIP_CHECK(hipFree(A_dev));
    HIP_CHECK(hipFree(B_dev));
    HIP_CHECK(hipFree(C_dev));
    return 0;
}

View File

@@ -0,0 +1,197 @@
# Adding New GPU Architecture Support
Guide for adding support for a new AMD GPU architecture to the CK Tile Dispatcher.
> **See also:** [Main Dispatcher README](../README.md) | [Codegen README](README.md)
## Overview
The dispatcher uses `arch_specs.json` as the **single source of truth** for GPU specifications:
```
arch_specs.json → generate_arch_specs.py → arch_specs_generated.py (Python)
→ arch_specs_generated.hpp (C++)
```
## Quick Start
```bash
# 1. Edit arch_specs.json
# 2. Run generator
python generate_arch_specs.py
# 3. Rebuild
cd ../build && cmake --build . -j8
# 4. Test
ctest
```
## Step-by-Step Guide
### Step 1: Edit arch_specs.json
Add new architecture under `"architectures"`:
```json
{
"architectures": {
"gfx1100": {
"family": "rdna3",
"description": "AMD Radeon RX 7000 series (RDNA3)",
"warp_size": 32,
"lds_capacity_kb": 64,
"warp_configs": [
[2, 4, 1],
[4, 2, 1]
],
"warp_tile_combos": {
"fp16_fp16_fp16": [[16, 16, 16], [32, 32, 16]],
"bf16_bf16_bf16": [[16, 16, 16], [32, 32, 16]]
}
}
}
}
```
### Step 2: Configuration Fields
| Field | Description | Example |
|-------|-------------|---------|
| `family` | GPU family | `"cdna3"`, `"rdna4"` |
| `description` | Human-readable name | `"AMD Instinct MI300"` |
| `warp_size` | Wave/warp size | `64` (CDNA), `32` (RDNA) |
| `lds_capacity_kb` | LDS memory in KB | `64` |
| `warp_configs` | Valid `[warp_m, warp_n, warp_k]` | `[[2,2,1], [4,4,1]]` |
| `warp_tile_combos` | Warp tiles per dtype | See below |
### Step 3: Warp Tile Combinations
Map data type combinations to valid warp tile sizes:
```json
"warp_tile_combos": {
"fp16_fp16_fp16": [[32, 32, 8], [16, 16, 16], [32, 32, 16]],
"bf16_bf16_bf16": [[32, 32, 8], [16, 16, 16]],
"fp8_fp8_fp16": [[32, 32, 16], [32, 32, 32]],
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]]
}
```
Key format: `{A_dtype}_{B_dtype}_{C_dtype}`
### Step 4: Run Generator
```bash
cd dispatcher/codegen
python generate_arch_specs.py
```
This generates:
- `arch_specs_generated.py` (Python module)
- `../include/ck_tile/dispatcher/arch_specs_generated.hpp` (C++ header)
### Step 5: Rebuild and Test
```bash
cd ../build
cmake --build . -j8
ctest --output-on-failure
```
### Step 6: Verify
```python
from arch_filter import ArchFilter
arch_filter = ArchFilter("gfx1100")
is_valid = arch_filter.is_kernel_valid(
    datatype_a="fp16", datatype_b="fp16", datatype_c="fp16",
    tile_m=128, tile_n=128, tile_k=32,
    warp_m=2, warp_n=2, warp_k=1,
    warp_tile_m=16, warp_tile_n=16, warp_tile_k=16
)
print(f"Valid: {is_valid}")
```
## Reference
### Supported Data Types
| Key | Description |
|-----|-------------|
| `fp16` | Half precision (16-bit) |
| `bf16` | Brain float 16 |
| `fp32` | Single precision (32-bit) |
| `fp64` | Double precision (64-bit) |
| `fp8` | 8-bit float (E4M3) |
| `bf8` | 8-bit brain float (E5M2) |
| `int8` | 8-bit integer |
| `int4` | 4-bit integer |
### GPU Families
| Family | Description |
|--------|-------------|
| `cdna2` | MI200 series (gfx90a) |
| `cdna3` | MI300 series (gfx942) |
| `cdna4` | MI350 series (gfx950) |
| `rdna3` | RX 7000 series (gfx1100) |
| `rdna4` | RX 9000 series (gfx1201) |
### Pipeline LDS Limits
| Pipeline | LDS Limit |
|----------|-----------|
| `compv4` | 32 KB |
| `preshufflev2` | 32 KB |
| `default` | 64 KB |
## Troubleshooting
### "Unknown GPU architecture"
1. Check architecture key matches exactly (e.g., `"gfx942"` not `"GFX942"`)
2. Verify you ran `generate_arch_specs.py`
3. Rebuild C++ code
### Kernels being rejected
```python
from arch_filter import ArchFilter, KernelConfig
arch_filter = ArchFilter("gfx942")
result = arch_filter.validate_kernel(config)  # config: a KernelConfig instance
print(f"Valid: {result.valid}")
for error in result.errors:
print(f" Error: {error}")
```
### Missing warp tile combination
1. Check `warp_tile_combos` in `arch_specs.json`
2. Ensure `[warp_tile_m, warp_tile_n, warp_tile_k]` is in the list
3. Verify data type key format
## File Structure
```
codegen/
├── arch_specs.json # Single source of truth (EDIT THIS)
├── generate_arch_specs.py # Generator script
├── arch_specs_generated.py # Generated Python module
└── ADDING_NEW_GPU.md # This file
include/ck_tile/dispatcher/
├── arch_specs_generated.hpp # Generated C++ header
└── arch_filter.hpp # C++ filter
```
## Best Practices
1. **Test thoroughly** - Run all tests after adding a new GPU
2. **Start minimal** - Add only validated configurations
3. **Document sources** - Note where warp tile combinations came from
4. **Keep in sync** - If using tile_engine, keep both updated
---
> **More info:** See [../README.md](../README.md) for full documentation.

View File

@@ -0,0 +1,125 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
# CK Tile GEMM Unified Code Generator
cmake_minimum_required(VERSION 3.16)
# Find Python (interpreter only; no development headers needed for codegen)
find_package(Python3 COMPONENTS Interpreter REQUIRED)
# Configuration: paths are captured relative to this CMakeLists' directory so
# that the helper function below works regardless of where it is called from.
set(CODEGEN_SCRIPT "${CMAKE_CURRENT_SOURCE_DIR}/unified_gemm_codegen.py")
set(CODEGEN_CONFIG "${CMAKE_CURRENT_SOURCE_DIR}/default_config.json")
set(CODEGEN_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm")
# Configurable options (cache variables; override with -DCK_TILE_GEMM_...=...)
set(CK_TILE_GEMM_DATATYPE "fp16" CACHE STRING "GEMM data type (fp16, bf16, fp32, fp8, bf8, int8)")
set(CK_TILE_GEMM_LAYOUT "rcr" CACHE STRING "GEMM layout (rcr, rrr, crr, ccr)")
set(CK_TILE_GEMM_VARIANTS "standard" CACHE STRING "GEMM variants (standard, preshuffle, multi_d)")
set(CK_TILE_GEMM_GPU_TARGET "gfx942" CACHE STRING "Target GPU architecture")
set(CK_TILE_GEMM_PARALLEL ON CACHE BOOL "Enable parallel generation")
# Custom target to run code generation on demand (build-time, not configure-time).
# NOTE(review): CK_TILE_GEMM_VARIANTS is a single STRING cache entry; a
# space-separated value (e.g. "standard preshuffle") is forwarded to the script
# as ONE argument under VERBATIM -- confirm the script splits it, or document
# that multiple variants must be semicolon-separated.
add_custom_target(generate_tile_gemm_kernels
    COMMAND ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT}
        --output-dir ${CODEGEN_OUTPUT_DIR}
        --datatype ${CK_TILE_GEMM_DATATYPE}
        --layout ${CK_TILE_GEMM_LAYOUT}
        --gpu-target ${CK_TILE_GEMM_GPU_TARGET}
        --config ${CODEGEN_CONFIG}
        --variants ${CK_TILE_GEMM_VARIANTS}
        # Append --no-parallel only when CK_TILE_GEMM_PARALLEL is OFF.
        $<$<NOT:$<BOOL:${CK_TILE_GEMM_PARALLEL}>>:--no-parallel>
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
    COMMENT "Generating CK Tile GEMM kernels and dispatcher wrappers..."
    VERBATIM
)
# Create output directory (at configure time, so the include path below exists)
file(MAKE_DIRECTORY ${CODEGEN_OUTPUT_DIR})
# Add generated headers to include path
include_directories(${CODEGEN_OUTPUT_DIR})
# Installation: ship the generator, its default config, and the docs
install(FILES
    ${CODEGEN_SCRIPT}
    ${CODEGEN_CONFIG}
    README.md
    DESTINATION share/ck_tile/codegen
)
# Helper function for projects to generate kernels
function(ck_tile_generate_gemm_kernels)
    # Generate GEMM kernels immediately (at configure time) by invoking the
    # unified codegen script. Keyword arguments (all optional):
    #   OUTPUT_DIR <dir>   - destination for generated headers
    #   DATATYPE <dt>      - fp16 (default), bf16, fp32, fp8, bf8, int8
    #   LAYOUT <layout>    - rcr (default), rrr, crr, ccr
    #   GPU_TARGET <arch>  - target GPU architecture (default: gfx942)
    #   CONFIG <file>      - tile/trait configuration JSON
    #   VARIANTS <v>...    - standard (default), preshuffle, multi_d
    #   PARALLEL           - enable parallel generation
    set(options PARALLEL)
    set(oneValueArgs OUTPUT_DIR DATATYPE LAYOUT GPU_TARGET CONFIG)
    set(multiValueArgs VARIANTS)
    cmake_parse_arguments(ARG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
    # Set defaults
    if(NOT ARG_OUTPUT_DIR)
        set(ARG_OUTPUT_DIR "${CMAKE_BINARY_DIR}/generated/tile_gemm")
    endif()
    if(NOT ARG_DATATYPE)
        set(ARG_DATATYPE "fp16")
    endif()
    if(NOT ARG_LAYOUT)
        set(ARG_LAYOUT "rcr")
    endif()
    if(NOT ARG_GPU_TARGET)
        set(ARG_GPU_TARGET "gfx942")
    endif()
    if(NOT ARG_CONFIG)
        # Default to the config that ships next to this module. CODEGEN_CONFIG
        # was captured at include time; the previous default used
        # CMAKE_CURRENT_SOURCE_DIR, which resolves to the *caller's* directory
        # inside a function and breaks when invoked from another CMakeLists.txt.
        set(ARG_CONFIG "${CODEGEN_CONFIG}")
    endif()
    if(NOT ARG_VARIANTS)
        set(ARG_VARIANTS "standard")
    endif()
    # Build command (VARIANTS is a CMake list, so each variant becomes its own argument)
    set(CMD ${Python3_EXECUTABLE} ${CODEGEN_SCRIPT}
        --output-dir ${ARG_OUTPUT_DIR}
        --datatype ${ARG_DATATYPE}
        --layout ${ARG_LAYOUT}
        --gpu-target ${ARG_GPU_TARGET}
        --config ${ARG_CONFIG}
        --variants ${ARG_VARIANTS}
    )
    if(NOT ARG_PARALLEL)
        list(APPEND CMD --no-parallel)
    endif()
    # Execute at configure time; a non-zero exit status aborts the configure step.
    execute_process(
        COMMAND ${CMD}
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
        RESULT_VARIABLE RESULT
        OUTPUT_VARIABLE OUTPUT
        ERROR_VARIABLE ERROR
    )
    if(NOT RESULT EQUAL 0)
        message(FATAL_ERROR "Failed to generate GEMM kernels:\n${ERROR}")
    else()
        message(STATUS "Generated GEMM kernels: ${OUTPUT}")
    endif()
endfunction()
# Example usage documentation
# (configure-time summary and usage hints only; no build behavior here)
message(STATUS "CK Tile GEMM Code Generator configured")
message(STATUS "  Script: ${CODEGEN_SCRIPT}")
message(STATUS "  Config: ${CODEGEN_CONFIG}")
message(STATUS "  Output: ${CODEGEN_OUTPUT_DIR}")
message(STATUS "")
message(STATUS "To generate kernels:")
message(STATUS "  cmake --build . --target generate_tile_gemm_kernels")
message(STATUS "")
message(STATUS "Or use CMake function:")
message(STATUS "  ck_tile_generate_gemm_kernels(")
message(STATUS "    OUTPUT_DIR ./generated")
message(STATUS "    DATATYPE fp16")
message(STATUS "    LAYOUT rcr")
message(STATUS "    VARIANTS standard preshuffle multi_d")
message(STATUS "    PARALLEL")
message(STATUS "  )")

View File

@@ -0,0 +1,123 @@
# CK Tile GEMM Unified Code Generator
Single source of truth for all GEMM kernel generation.
> **See also:** [Main Dispatcher README](../README.md) for installation and core concepts.
## Quick Start
```bash
cd dispatcher/codegen
# Generate standard FP16 kernels
python3 unified_gemm_codegen.py \
--output-dir ../build/generated_kernels \
--datatype fp16 \
--layout rcr \
--variants standard
# Generate all variants
python3 unified_gemm_codegen.py \
--output-dir ../build/generated_kernels \
--variants standard preshuffle multi_d
```
## Using from Python
```python
from ctypes_utils import CodegenRunner, KernelConfig
# Generate from specific config
config = KernelConfig(tile_m=256, tile_n=256, tile_k=64)
codegen = CodegenRunner()
result = codegen.generate_from_config(config)
# Generate variant
result = codegen.generate("preshuffle")
# Generate all
results = codegen.generate_all()
```
## Command Line Options
| Option | Values | Description |
|--------|--------|-------------|
| `--output-dir` | path | Output directory |
| `--datatype` | `fp16`, `bf16`, `fp32`, `fp8`, `bf8`, `int8` | Data type |
| `--layout` | `rcr`, `rrr`, `crr`, `ccr` | Matrix layouts |
| `--gpu-target` | `gfx942`, `gfx90a`, `gfx950` | Target GPU |
| `--variants` | `standard`, `preshuffle`, `multi_d` | Kernel variants |
| `--preselected` | `fp16_rcr_essential`, etc. | Predefined kernel set |
### Layout Notation
- `R` = Row-major, `C` = Column-major
- Order: A, B, C (e.g., `rcr` = A row, B col, C row)
## Variants
### Standard
Basic GEMM: `C = A × B`
### PreShuffle
Optimized weight access with LDS pre-shuffling. Best for large matrices.
### Multi-D
Element-wise fusion: `C = op(A × B + D0 + D1 + ...)`
Supported ops: `PassThrough`, `MultiDAdd`, `Relu`, `Gelu`, `Sigmoid`, `Tanh`
## Output Structure
```
generated_kernels/
├── gemm_fp16_rcr_compv4_..._128x128x32_....hpp
├── gemm_fp16_rcr_compv4_..._preshuffle.hpp
├── gemm_fp16_rcr_compv4_..._multid_Relu_d1.hpp
└── ...
```
## Configuration Files
### arch_specs.json
GPU architecture specifications (single source of truth):
```json
{
"architectures": {
"gfx942": {
"family": "cdna3",
"warp_size": 64,
"warp_configs": [[2, 2, 1], [4, 4, 1]],
...
}
}
}
```
### preselected_kernels.py
Curated kernel sets for common use cases.
## Adding New GPU Support
See [ADDING_NEW_GPU.md](ADDING_NEW_GPU.md) for complete guide.
Quick steps:
1. Edit `arch_specs.json`
2. Run `python generate_arch_specs.py`
3. Rebuild
## Troubleshooting
| Issue | Solution |
|-------|----------|
| "Arguments not supported" | Check tile config validity |
| Missing element-wise op | Check `elementwise_ops.hpp` |
| Compilation errors | Verify C++17, include paths |
---
> **More info:** See [../README.md](../README.md) for full documentation.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,270 @@
{
"_comment": "Single source of truth for GPU architecture specifications. Edit this file to add new GPU support.",
"_version": "1.2.0",
"_instructions": "See ADDING_NEW_GPU.md for instructions on adding new GPU support.",
"_supported_arch_note": "CK Tile supports: GFX9 (gfx908, gfx90a, gfx942, gfx950), GFX10.3 (gfx103x), GFX11 (gfx110x, gfx115x), GFX12 (gfx120x)",
"architectures": {
"gfx908": {
"family": "cdna1",
"target_family": "gfx9",
"architecture": "cdna",
"description": "AMD Instinct MI100",
"warp_size": 64,
"lds_capacity_kb": 64,
"warp_configs": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1]
],
"warp_tile_combos": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]]
}
},
"gfx90a": {
"family": "cdna2",
"target_family": "gfx9",
"architecture": "cdna",
"description": "AMD Instinct MI200 series",
"warp_size": 64,
"lds_capacity_kb": 64,
"warp_configs": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1]
],
"warp_tile_combos": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]],
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]]
}
},
"gfx942": {
"family": "cdna3",
"target_family": "gfx9",
"architecture": "cdna",
"description": "AMD Instinct MI300 series",
"warp_size": 64,
"lds_capacity_kb": 64,
"warp_configs": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1]
],
"warp_tile_combos": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]],
"bf8_fp8_fp32": [[32, 32, 16]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]]
}
},
"gfx950": {
"family": "cdna4",
"target_family": "gfx9",
"architecture": "cdna",
"description": "AMD Instinct MI350 series",
"warp_size": 64,
"lds_capacity_kb": 160,
"warp_configs": [
[1, 4, 1],
[2, 2, 1],
[4, 1, 1]
],
"warp_tile_combos": {
"fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [4, 64, 16], [64, 4, 16]],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
"fp8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 128], [32, 32, 64]],
"bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
"int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
"pk_fp4_pk_fp4_fp32": [[16, 16, 128]]
}
},
"gfx1100": {
"family": "rdna3",
"target_family": "gfx11",
"architecture": "rdna",
"description": "AMD Radeon RX 7900 series (RDNA3)",
"warp_size": 32,
"lds_capacity_kb": 64,
"warp_configs": [
[2, 4, 1],
[1, 8, 1],
[8, 1, 1],
[4, 2, 1]
],
"warp_tile_combos": {
"fp16_fp16_fp32": [[16, 16, 16]],
"bf16_bf16_fp32": [[16, 16, 16]],
"int8_int8_int32": [[16, 16, 16]]
}
},
"gfx1200": {
"family": "rdna4",
"target_family": "gfx12",
"architecture": "rdna",
"description": "AMD Radeon RX 9000 series (RDNA4)",
"warp_size": 32,
"lds_capacity_kb": 64,
"warp_configs": [
[2, 4, 1],
[1, 8, 1],
[8, 1, 1],
[4, 2, 1]
],
"warp_tile_combos": {
"fp16_fp16_fp32": [[16, 16, 16]],
"bf16_bf16_fp32": [[16, 16, 16]],
"fp8_fp8_fp32": [[16, 16, 16]],
"bf8_bf8_fp32": [[16, 16, 16]],
"fp8_bf8_fp32": [[16, 16, 16]],
"bf8_fp8_fp32": [[16, 16, 16]],
"int8_int8_int32": [[16, 16, 16]]
}
},
"gfx1201": {
"family": "rdna4",
"target_family": "gfx12",
"architecture": "rdna",
"description": "AMD Radeon RX 9000 series (RDNA4)",
"warp_size": 32,
"lds_capacity_kb": 64,
"warp_configs": [
[2, 4, 1],
[1, 8, 1],
[8, 1, 1],
[4, 2, 1]
],
"warp_tile_combos": {
"fp16_fp16_fp32": [[16, 16, 16]],
"bf16_bf16_fp32": [[16, 16, 16]],
"fp8_fp8_fp32": [[16, 16, 16]],
"bf8_bf8_fp32": [[16, 16, 16]],
"fp8_bf8_fp32": [[16, 16, 16]],
"bf8_fp8_fp32": [[16, 16, 16]],
"int8_int8_int32": [[16, 16, 16]]
}
}
},
"element_sizes": {
"fp16": 2,
"bf16": 2,
"fp32": 4,
"fp64": 8,
"fp8": 1,
"bf8": 1,
"int8": 1,
"int4": 0.5,
"pk_fp4": 0.5,
"int32": 4
},
"datatype_cpp_map": {
"_comment": "Maps dtype string to CK Tile C++ type for code generation",
"fp16": "ck_tile::half_t",
"bf16": "ck_tile::bf16_t",
"fp32": "float",
"fp64": "double",
"fp8": "ck_tile::fp8_t",
"bf8": "ck_tile::bf8_t",
"int8": "ck_tile::int8_t",
"int4": "ck_tile::pk_int4_t",
"pk_fp4": "ck_tile::pk_fp4_t",
"int32": "ck_tile::int32_t"
},
"dtype_combinations": {
"_comment": "All valid (A, B) -> Acc combinations for GEMM from warp_gemm_dispatcher.hpp",
"fp32_fp32": {"acc": "fp32", "notes": "Full precision"},
"fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"},
"bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"},
"fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"},
"fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"},
"bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"},
"bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"},
"int8_int8": {"acc": "int32", "notes": "Integer GEMM"},
"pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"}
},
"layout_cpp_map": {
"_comment": "Maps layout character to CK Tile C++ type",
"r": "ck_tile::tensor_layout::gemm::RowMajor",
"c": "ck_tile::tensor_layout::gemm::ColumnMajor"
},
"pipeline_lds_limits": {
"_comment": "LDS capacity limits in bytes for different pipeline types",
"mem": 65536,
"compv1": 65536,
"compv2": 65536,
"compv3": 65536,
"compv4": 32768,
"compv5": 65536,
"preshufflev1": 32768,
"preshufflev2": 32768,
"default": 65536
},
"unsupported_trait_combos": {
"_comment": "Only 'mem' pipeline supports interwave scheduler. All compute pipelines only support intrawave.",
"combinations": [
["compv3", "cshuffle", "interwave"],
["compv3", "default", "interwave"],
["compv4", "cshuffle", "interwave"],
["compv4", "default", "interwave"],
["compv5", "cshuffle", "interwave"],
["compv5", "default", "interwave"],
["compv6", "cshuffle", "interwave"],
["compv6", "default", "interwave"],
["comp_async", "cshuffle", "interwave"],
["comp_async", "default", "interwave"]
]
},
"preshuffle_warp_tile_combos": {
"_comment": "Preshuffle-specific warp tile combinations (subset of standard GEMM, no [4, 64, 16])",
"gfx90a": {
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]]
},
"gfx942": {
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
"int8_int8_int32": [[16, 16, 32], [32, 32, 16]]
},
"gfx950": {
"fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
"bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32], [64, 4, 16]],
"fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64], [16, 16, 128], [32, 32, 64]],
"bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32], [16, 16, 128], [32, 32, 64]]
}
},
"preshuffle_pipelines": {
"_comment": "Pipelines supported for preshuffle GEMM variant",
"supported": ["preshufflev2"]
}
}

View File

@@ -0,0 +1,358 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
Generated from: arch_specs.json
Generated at: 2026-01-05T19:34:01.224422
To update this file:
1. Edit arch_specs.json
2. Run: python generate_arch_specs.py
This module provides architecture-specific configurations for kernel filtering.
"""
from typing import Dict, List, Set, Tuple
# =============================================================================
# Architecture Data (Generated from arch_specs.json)
# =============================================================================
# GPU architecture to family mapping
# NOTE: This module is auto-generated from arch_specs.json; edit the JSON and
# re-run generate_arch_specs.py -- manual edits to these tables are overwritten.
ARCH_FAMILY_MAP: Dict[str, str] = {
    "gfx908": "cdna1",
    "gfx90a": "cdna2",
    "gfx942": "cdna3",
    "gfx950": "cdna4",
    "gfx1100": "rdna3",
    "gfx1200": "rdna4",
    "gfx1201": "rdna4",
}
# Element size in bytes for each data type
# (fractional sizes such as 0.5 denote packed 4-bit types: two values per byte)
ELEMENT_SIZE_MAP: Dict[str, float] = {
    "fp16": 2,
    "bf16": 2,
    "fp32": 4,
    "fp64": 8,
    "fp8": 1,
    "bf8": 1,
    "int8": 1,
    "int4": 0.5,
    "pk_fp4": 0.5,
    "int32": 4,
}
# Supported warp configurations per architecture [warp_m, warp_n, warp_k]
WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {
    "gfx908": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
    "gfx90a": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
    "gfx942": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
    "gfx950": [[1, 4, 1], [2, 2, 1], [4, 1, 1]],
    "gfx1100": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
    "gfx1200": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
    "gfx1201": [[2, 4, 1], [1, 8, 1], [8, 1, 1], [4, 2, 1]],
}
# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...]
# dtype_key format is "<a>_<b>_<acc>", e.g. "fp16_fp16_fp32"; lookups lowercase
# the key (see get_warp_tile_combos below).
WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {
    "gfx908": {
        "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
        "fp16_fp16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
        "bf16_bf16_fp32": [[32, 32, 8], [16, 16, 16], [32, 32, 16], [16, 16, 32]],
        "int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
    },
    "gfx90a": {
        "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
        "fp16_fp16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
        "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]],
        "int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
    },
    "gfx942": {
        "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
        "fp16_fp16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
        "fp8_bf8_fp32": [[32, 32, 16], [16, 16, 32], [32, 32, 32]],
        "bf8_fp8_fp32": [[32, 32, 16]],
        "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
        "int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
    },
    "gfx950": {
        "fp32_fp32_fp32": [[16, 16, 4], [16, 16, 16]],
        "fp16_fp16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "bf16_bf16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [4, 64, 16],
            [64, 4, 16],
        ],
        "fp8_fp8_fp32": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 32],
            [16, 16, 64],
            [16, 16, 128],
            [32, 32, 64],
        ],
        "fp8_bf8_fp32": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 32],
            [16, 16, 128],
            [32, 32, 64],
        ],
        "bf8_fp8_fp32": [[32, 32, 16], [16, 16, 128], [32, 32, 64]],
        "bf8_bf8_fp32": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 32],
            [16, 16, 64],
            [16, 16, 128],
            [32, 32, 64],
        ],
        "int8_int8_int32": [[32, 32, 16], [16, 16, 32]],
        "pk_fp4_pk_fp4_fp32": [[16, 16, 128]],
    },
    "gfx1100": {
        "fp16_fp16_fp32": [[16, 16, 16]],
        "bf16_bf16_fp32": [[16, 16, 16]],
        "int8_int8_int32": [[16, 16, 16]],
    },
    "gfx1200": {
        "fp16_fp16_fp32": [[16, 16, 16]],
        "bf16_bf16_fp32": [[16, 16, 16]],
        "fp8_fp8_fp32": [[16, 16, 16]],
        "bf8_bf8_fp32": [[16, 16, 16]],
        "fp8_bf8_fp32": [[16, 16, 16]],
        "bf8_fp8_fp32": [[16, 16, 16]],
        "int8_int8_int32": [[16, 16, 16]],
    },
    "gfx1201": {
        "fp16_fp16_fp32": [[16, 16, 16]],
        "bf16_bf16_fp32": [[16, 16, 16]],
        "fp8_fp8_fp32": [[16, 16, 16]],
        "bf8_bf8_fp32": [[16, 16, 16]],
        "fp8_bf8_fp32": [[16, 16, 16]],
        "bf8_fp8_fp32": [[16, 16, 16]],
        "int8_int8_int32": [[16, 16, 16]],
    },
}
# Preshuffle-specific warp tile combinations (subset of standard GEMM)
PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {
    "gfx90a": {
        "fp16_fp16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [64, 4, 16],
        ],
        "bf16_bf16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [64, 4, 16],
        ],
        "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32]],
        "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32]],
    },
    "gfx942": {
        "fp16_fp16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [64, 4, 16],
        ],
        "bf16_bf16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [64, 4, 16],
        ],
        "fp8_fp8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 32], [16, 16, 64]],
        "bf8_bf8_fp32": [[32, 32, 16], [32, 32, 32], [16, 16, 64], [16, 16, 32]],
        "int8_int8_int32": [[16, 16, 32], [32, 32, 16]],
    },
    "gfx950": {
        "fp16_fp16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [64, 4, 16],
        ],
        "bf16_bf16_fp32": [
            [32, 32, 8],
            [16, 16, 16],
            [32, 32, 16],
            [16, 16, 32],
            [64, 4, 16],
        ],
        "fp8_fp8_fp32": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 32],
            [16, 16, 64],
            [16, 16, 128],
            [32, 32, 64],
        ],
        "bf8_bf8_fp32": [
            [32, 32, 16],
            [32, 32, 32],
            [16, 16, 64],
            [16, 16, 32],
            [16, 16, 128],
            [32, 32, 64],
        ],
    },
}
# Preshuffle-supported pipelines
PRESHUFFLE_PIPELINES: List[str] = ["preshufflev2"]
# LDS capacity limits per pipeline type (in bytes)
LDS_CAPACITY_LIMITS: Dict[str, int] = {
    "mem": 65536,
    "compv1": 65536,
    "compv2": 65536,
    "compv3": 65536,
    "compv4": 32768,
    "compv5": 65536,
    "preshufflev1": 32768,
    "preshufflev2": 32768,
    "default": 65536,
}
# Unsupported trait combinations: (pipeline, epilogue, scheduler)
# (per arch_specs.json: only the "mem" pipeline supports the interwave scheduler)
TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = {
    ("compv3", "cshuffle", "interwave"),
    ("compv3", "default", "interwave"),
    ("compv4", "cshuffle", "interwave"),
    ("compv4", "default", "interwave"),
    ("compv5", "cshuffle", "interwave"),
    ("compv5", "default", "interwave"),
    ("compv6", "cshuffle", "interwave"),
    ("compv6", "default", "interwave"),
    ("comp_async", "cshuffle", "interwave"),
    ("comp_async", "default", "interwave"),
}
# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes
DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = {
    "fp32_fp32": {"acc": "fp32", "notes": "Full precision"},
    "fp16_fp16": {"acc": "fp32", "notes": "Standard half precision"},
    "bf16_bf16": {"acc": "fp32", "notes": "Brain float 16"},
    "fp8_fp8": {"acc": "fp32", "notes": "FP8 E4M3"},
    "fp8_bf8": {"acc": "fp32", "notes": "Mixed FP8/BF8"},
    "bf8_fp8": {"acc": "fp32", "notes": "Mixed BF8/FP8"},
    "bf8_bf8": {"acc": "fp32", "notes": "BF8 E5M2"},
    "int8_int8": {"acc": "int32", "notes": "Integer GEMM"},
    "pk_fp4_pk_fp4": {"acc": "fp32", "notes": "Packed 4-bit float"},
}
# =============================================================================
# Helper Functions
# =============================================================================
def get_supported_archs() -> List[str]:
    """Get list of all supported GPU architectures (the keys of ARCH_FAMILY_MAP)."""
    return list(ARCH_FAMILY_MAP.keys())
def get_arch_family(gpu_arch: str) -> str:
    """Get the GPU family for an architecture.

    Lookup is case-insensitive; returns "unknown" for unrecognized architectures.
    """
    return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown")
def get_element_size(dtype: str) -> float:
    """Get element size in bytes for a data type.

    Unknown types fall back to 2.0 bytes (the fp16/bf16 size).
    """
    return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0)
def get_warp_configs(gpu_arch: str) -> List[List[int]]:
    """Get supported warp configurations ([warp_m, warp_n, warp_k]) for an architecture.

    Returns an empty list for unrecognized architectures.
    """
    return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), [])
def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]:
    """Get supported warp tile combinations for arch and data types.

    dtype_key format: "<a>_<b>_<acc>", e.g. "fp16_fp16_fp32".
    Returns an empty list when either the arch or the dtype key is unknown.
    """
    gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {})
    return gpu_combos.get(dtype_key.lower(), [])
def get_lds_limit(pipeline: str) -> int:
    """Get LDS capacity limit in bytes for a pipeline type.

    Unknown pipelines fall back to the "default" entry.
    """
    return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"])
def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool:
    """Check if a (pipeline, epilogue, scheduler) trait combination is unsupported.

    Comparison is case-insensitive.
    """
    return (
        pipeline.lower(),
        epilogue.lower(),
        scheduler.lower(),
    ) in TRAIT_UNSUPPORTED_COMBINATIONS
def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]:
    """Get accumulator type and notes for a dtype combination.

    Unknown combinations report fp32 accumulation with "unknown" notes.
    """
    key = f"{dtype_a.lower()}_{dtype_b.lower()}"
    return DTYPE_COMBINATIONS.get(key, {"acc": "fp32", "notes": "unknown"})
def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool:
    """Check if a (dtype_a, dtype_b) combination is valid (case-insensitive)."""
    key = f"{dtype_a.lower()}_{dtype_b.lower()}"
    return key in DTYPE_COMBINATIONS
def get_valid_dtype_combos() -> List[str]:
    """Get list of all valid dtype combinations (keys like "fp16_fp16")."""
    return list(DTYPE_COMBINATIONS.keys())

View File

@@ -0,0 +1,27 @@
{
"tile_config": {
"tile_m": [128, 256],
"tile_n": [128, 256],
"tile_k": [32, 64],
"warp_m": [2, 4],
"warp_n": [2, 4],
"warp_k": [1],
"warp_tile_m": [16, 32],
"warp_tile_n": [16, 32],
"warp_tile_k": [16]
},
"trait_config": {
"pipeline": ["compv4"],
"epilogue": ["cshuffle"],
"scheduler": ["intrawave"],
"pad_m": [false],
"pad_n": [false],
"pad_k": [false],
"persistent": [false, true]
},
"multi_d_config": {
"elementwise_ops": ["MultiDAdd", "Relu", "Gelu"],
"num_d_tensors": [1, 2]
}
}

View File

@@ -0,0 +1,452 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Architecture Specs Generator
Generates both Python and C++ code from a single JSON source of truth.
This ensures consistency between Python codegen and C++ runtime filtering.
Usage:
python generate_arch_specs.py [--json arch_specs.json] [--output-dir .]
# Regenerate after editing arch_specs.json:
python generate_arch_specs.py
Output:
- arch_specs_generated.py (Python module with arch data)
- arch_specs_generated.hpp (C++ header with arch data)
"""
import json
import argparse
from pathlib import Path
from datetime import datetime
from typing import Dict, Any
SCRIPT_DIR = Path(__file__).parent
def load_arch_specs(json_path: Path) -> Dict[str, Any]:
    """Load architecture specifications from a JSON file.

    Args:
        json_path: Path to the specs file (normally arch_specs.json).

    Returns:
        The parsed JSON document as a dict.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Pin the encoding so parsing does not depend on the platform default.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
def generate_python_module(specs: Dict[str, Any], output_path: Path):
    """Generate the Python arch-specs module from the parsed JSON specs.

    Args:
        specs: Parsed contents of arch_specs.json (see load_arch_specs).
        output_path: Destination .py file; overwritten if it already exists.
    """
    timestamp = datetime.now().isoformat()

    # Extract data
    archs = specs["architectures"]
    element_sizes = specs["element_sizes"]
    pipeline_limits = specs["pipeline_lds_limits"]
    unsupported = specs["unsupported_trait_combos"]["combinations"]

    # Drop "_comment"-style metadata keys so they never leak into generated
    # code (same convention as pipeline_limits_clean below).
    element_sizes_clean = {
        k: v for k, v in element_sizes.items() if not k.startswith("_")
    }

    # Build warp configs dict
    warp_configs_str = "{\n"
    for arch, data in archs.items():
        warp_configs_str += f'    "{arch}": {data["warp_configs"]},\n'
    warp_configs_str += "}"

    # Build warp tile combos dict
    warp_tile_str = "{\n"
    for arch, data in archs.items():
        warp_tile_str += f'    "{arch}": {{\n'
        for dtype, combos in data["warp_tile_combos"].items():
            warp_tile_str += f'        "{dtype}": {combos},\n'
        warp_tile_str += "    },\n"
    warp_tile_str += "}"

    # Build arch family map
    arch_family_str = "{\n"
    for arch, data in archs.items():
        arch_family_str += f'    "{arch}": "{data["family"]}",\n'
    arch_family_str += "}"

    # Build unsupported combos set
    unsupported_str = "{\n"
    for combo in unsupported:
        unsupported_str += f'    ("{combo[0]}", "{combo[1]}", "{combo[2]}"),\n'
    unsupported_str += "}"

    # Pipeline LDS limits (strip "_comment"-style metadata keys)
    pipeline_limits_clean = {
        k: v for k, v in pipeline_limits.items() if not k.startswith("_")
    }

    # Build dtype combinations dict
    dtype_combos = specs.get("dtype_combinations", {})
    dtype_combos_str = "{\n"
    for key, info in dtype_combos.items():
        if not key.startswith("_"):
            dtype_combos_str += f'    "{key}": {{"acc": "{info["acc"]}", "notes": "{info["notes"]}"}},\n'
    dtype_combos_str += "}"

    # Build preshuffle warp tile combos dict (operator-specific)
    preshuffle_combos = specs.get("preshuffle_warp_tile_combos", {})
    preshuffle_warp_tile_str = "{\n"
    for arch, dtype_combos_dict in preshuffle_combos.items():
        if not arch.startswith("_"):
            preshuffle_warp_tile_str += f'    "{arch}": {{\n'
            for dtype, combos in dtype_combos_dict.items():
                preshuffle_warp_tile_str += f'        "{dtype}": {combos},\n'
            preshuffle_warp_tile_str += "    },\n"
    preshuffle_warp_tile_str += "}"

    # Build preshuffle pipelines list
    preshuffle_pipelines = specs.get("preshuffle_pipelines", {}).get(
        "supported", ["preshufflev2"]
    )
    preshuffle_pipelines_str = str(preshuffle_pipelines)

    # NOTE: the header now includes the copyright line so that regeneration
    # reproduces the checked-in file (which carries both license lines).
    content = f'''# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!

Generated from: arch_specs.json
Generated at: {timestamp}

To update this file:
1. Edit arch_specs.json
2. Run: python generate_arch_specs.py

This module provides architecture-specific configurations for kernel filtering.
"""

from typing import Dict, List, Set, Tuple

# =============================================================================
# Architecture Data (Generated from arch_specs.json)
# =============================================================================

# GPU architecture to family mapping
ARCH_FAMILY_MAP: Dict[str, str] = {arch_family_str}

# Element size in bytes for each data type
ELEMENT_SIZE_MAP: Dict[str, float] = {element_sizes_clean}

# Supported warp configurations per architecture [warp_m, warp_n, warp_k]
WARP_SUPPORTED_COMBINATIONS: Dict[str, List[List[int]]] = {warp_configs_str}

# Supported warp tile combinations: arch -> dtype_key -> [[warp_tile_m, n, k], ...]
WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {warp_tile_str}

# Preshuffle-specific warp tile combinations (subset of standard GEMM)
PRESHUFFLE_WARP_TILE_SUPPORTED_COMBINATIONS: Dict[str, Dict[str, List[List[int]]]] = {preshuffle_warp_tile_str}

# Preshuffle-supported pipelines
PRESHUFFLE_PIPELINES: List[str] = {preshuffle_pipelines_str}

# LDS capacity limits per pipeline type (in bytes)
LDS_CAPACITY_LIMITS: Dict[str, int] = {pipeline_limits_clean}

# Unsupported trait combinations: (pipeline, epilogue, scheduler)
TRAIT_UNSUPPORTED_COMBINATIONS: Set[Tuple[str, str, str]] = {unsupported_str}

# Valid dtype combinations: (A_dtype, B_dtype) -> acc_dtype and notes
DTYPE_COMBINATIONS: Dict[str, Dict[str, str]] = {dtype_combos_str}

# =============================================================================
# Helper Functions
# =============================================================================

def get_supported_archs() -> List[str]:
    """Get list of all supported GPU architectures."""
    return list(ARCH_FAMILY_MAP.keys())

def get_arch_family(gpu_arch: str) -> str:
    """Get the GPU family for an architecture."""
    return ARCH_FAMILY_MAP.get(gpu_arch.lower(), "unknown")

def get_element_size(dtype: str) -> float:
    """Get element size in bytes for a data type."""
    return ELEMENT_SIZE_MAP.get(dtype.lower(), 2.0)

def get_warp_configs(gpu_arch: str) -> List[List[int]]:
    """Get supported warp configurations for an architecture."""
    return WARP_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), [])

def get_warp_tile_combos(gpu_arch: str, dtype_key: str) -> List[List[int]]:
    """Get supported warp tile combinations for arch and data types."""
    gpu_combos = WARP_TILE_SUPPORTED_COMBINATIONS.get(gpu_arch.lower(), {{}})
    return gpu_combos.get(dtype_key.lower(), [])

def get_lds_limit(pipeline: str) -> int:
    """Get LDS capacity limit for a pipeline type."""
    return LDS_CAPACITY_LIMITS.get(pipeline.lower(), LDS_CAPACITY_LIMITS["default"])

def is_trait_combo_unsupported(pipeline: str, epilogue: str, scheduler: str) -> bool:
    """Check if a trait combination is unsupported."""
    return (pipeline.lower(), epilogue.lower(), scheduler.lower()) in TRAIT_UNSUPPORTED_COMBINATIONS

def get_dtype_info(dtype_a: str, dtype_b: str) -> Dict[str, str]:
    """Get accumulator type and notes for a dtype combination."""
    key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}"
    return DTYPE_COMBINATIONS.get(key, {{"acc": "fp32", "notes": "unknown"}})

def is_dtype_combo_valid(dtype_a: str, dtype_b: str) -> bool:
    """Check if a dtype combination is valid."""
    key = f"{{dtype_a.lower()}}_{{dtype_b.lower()}}"
    return key in DTYPE_COMBINATIONS

def get_valid_dtype_combos() -> List[str]:
    """Get list of all valid dtype combinations."""
    return list(DTYPE_COMBINATIONS.keys())
'''
    # Pin the encoding so the generated file is identical across platforms.
    output_path.write_text(content, encoding="utf-8")
    print(f"Generated: {output_path}")
def generate_cpp_header(specs: Dict[str, Any], output_path: Path):
    """Generate the C++ header ``arch_specs_generated.hpp`` from arch specs.

    Args:
        specs: Parsed ``arch_specs.json``.  Keys consumed here:
            ``architectures``, ``element_sizes`` and ``pipeline_lds_limits``.
        output_path: Destination file; overwritten unconditionally.
    """
    timestamp = datetime.now().isoformat()

    # Extract data.
    archs = specs["architectures"]
    element_sizes = specs["element_sizes"]
    pipeline_limits = specs["pipeline_lds_limits"]
    # NOTE: an earlier revision evaluated
    # specs["unsupported_trait_combos"]["combinations"] here and discarded
    # the result; that dead statement has been removed.  The generated
    # is_trait_unsupported() below is still hard-coded rather than derived
    # from that JSON section -- TODO: emit it from the JSON so the two
    # cannot drift apart.

    # Build arch enum entries and the string<->enum conversion tables.
    arch_enums = []
    arch_to_string_cases = []
    string_to_arch_cases = []
    for arch, data in archs.items():
        enum_name = arch.upper().replace("GFX", "GFX_")
        arch_enums.append(f"    {enum_name}, // {data['description']}")
        arch_to_string_cases.append(
            f'        case GpuArch::{enum_name}: return "{arch}";'
        )
        string_to_arch_cases.append(
            f'    if (arch_str == "{arch}") return GpuArch::{enum_name};'
        )

    # Build the per-arch warp-configuration switch.
    warp_config_cases = []
    for arch, data in archs.items():
        enum_name = arch.upper().replace("GFX", "GFX_")
        configs = ", ".join(
            [f"{{{c[0]}, {c[1]}, {c[2]}}}" for c in data["warp_configs"]]
        )
        warp_config_cases.append(
            f"        case GpuArch::{enum_name}: return {{{configs}}};"
        )

    # Build element size switch.
    # Include all data types defined in kernel_key.hpp DataType enum;
    # JSON entries without a matching enum member are skipped silently.
    elem_size_cases = []
    dtype_enum_map = {
        "fp16": "FP16",
        "bf16": "BF16",
        "fp32": "FP32",
        "fp64": "FP64",
        "fp8": "FP8",
        "bf8": "BF8",
        "int8": "INT8",
        "int4": "INT4",
        "int32": "INT32",
    }
    for dtype, size in element_sizes.items():
        if dtype in dtype_enum_map:
            elem_size_cases.append(
                f"        case DataType::{dtype_enum_map[dtype]}: return {float(size)}f;"
            )

    # Build LDS limits; "default" has no enum member and becomes the
    # fallback return value instead of a branch.
    lds_limit_cases = []
    pipeline_enum_map = {
        "mem": "Mem",
        "compv1": "CompV1",
        "compv2": "CompV2",
        "compv3": "CompV3",
        "compv4": "CompV4",
        "compv5": "CompV5",
        "preshufflev1": "PreShuffleV1",
        "preshufflev2": "PreShuffleV2",
    }
    default_lds = pipeline_limits.get("default", 65536)
    for pipeline, limit in pipeline_limits.items():
        if pipeline in pipeline_enum_map:
            lds_limit_cases.append(
                f"    if (pipeline == Pipeline::{pipeline_enum_map[pipeline]}) return {limit};"
            )

    content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

/**
 * AUTO-GENERATED FILE - DO NOT EDIT DIRECTLY!
 *
 * Generated from: arch_specs.json
 * Generated at: {timestamp}
 *
 * To update this file:
 *   1. Edit arch_specs.json
 *   2. Run: python generate_arch_specs.py
 */

#pragma once

#include "ck_tile/dispatcher/kernel_key.hpp"

#include <array>
#include <string>
#include <vector>
#include <cstdint>

namespace ck_tile {{
namespace dispatcher {{
namespace arch_specs {{

// =============================================================================
// GPU Architecture Enum (Generated)
// =============================================================================

enum class GpuArch : std::uint8_t {{
{chr(10).join(arch_enums)}
    UNKNOWN
}};

// =============================================================================
// String Conversion Functions (Generated)
// =============================================================================

inline std::string arch_to_string(GpuArch arch) {{
    switch (arch) {{
{chr(10).join(arch_to_string_cases)}
        default: return "unknown";
    }}
}}

inline GpuArch string_to_arch(const std::string& arch_str) {{
{chr(10).join(string_to_arch_cases)}
    return GpuArch::UNKNOWN;
}}

// =============================================================================
// Element Size (Generated)
// =============================================================================

inline float element_size(DataType dtype) {{
    switch (dtype) {{
{chr(10).join(elem_size_cases)}
        default: return 2.0f;
    }}
}}

// =============================================================================
// Warp Configurations (Generated)
// =============================================================================

using WarpConfig = std::array<int, 3>;

inline std::vector<WarpConfig> get_supported_warp_configs(GpuArch arch) {{
    switch (arch) {{
{chr(10).join(warp_config_cases)}
        default: return {{}};
    }}
}}

// =============================================================================
// LDS Capacity Limits (Generated)
// =============================================================================

inline std::size_t get_lds_capacity(Pipeline pipeline) {{
{chr(10).join(lds_limit_cases)}
    return {default_lds}; // Default
}}

// =============================================================================
// Unsupported Trait Combinations (Generated)
// =============================================================================

inline bool is_trait_unsupported(Pipeline pipeline, [[maybe_unused]] Epilogue epilogue, Scheduler scheduler) {{
    // Generated from unsupported_trait_combos in arch_specs.json
    if (scheduler == Scheduler::Interwave) {{
        if (pipeline == Pipeline::CompV3 || pipeline == Pipeline::CompV4) {{
            return true;
        }}
    }}
    return false;
}}

}} // namespace arch_specs
}} // namespace dispatcher
}} // namespace ck_tile
"""
    output_path.write_text(content)
    print(f"Generated: {output_path}")
def main():
    """CLI entry point: regenerate the Python and C++ arch-spec sources."""
    arg_parser = argparse.ArgumentParser(
        description="Generate Python and C++ code from arch_specs.json"
    )
    arg_parser.add_argument(
        "--json",
        type=Path,
        default=SCRIPT_DIR / "arch_specs.json",
        help="Path to arch_specs.json",
    )
    arg_parser.add_argument(
        "--output-dir",
        type=Path,
        default=SCRIPT_DIR,
        help="Output directory for generated files",
    )
    arg_parser.add_argument(
        "--cpp-output-dir",
        type=Path,
        default=None,
        help="Output directory for C++ header (defaults to dispatcher/include/...)",
    )
    options = arg_parser.parse_args()

    # Load the single source of truth.
    print(f"Loading: {options.json}")
    spec_data = load_arch_specs(options.json)

    # Emit the Python module next to the other generated files.
    generate_python_module(spec_data, options.output_dir / "arch_specs_generated.py")

    # Emit the C++ header; default location is under dispatcher/include.
    if options.cpp_output_dir:
        header_path = options.cpp_output_dir / "arch_specs_generated.hpp"
    else:
        header_path = (
            SCRIPT_DIR.parent
            / "include"
            / "ck_tile"
            / "dispatcher"
            / "arch_specs_generated.hpp"
        )
    header_path.parent.mkdir(parents=True, exist_ok=True)
    generate_cpp_header(spec_data, header_path)

    print("\nDone! To apply changes:")
    print(" 1. Python code will automatically use arch_specs_generated.py")
    print(" 2. C++ code includes arch_specs_generated.hpp")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,429 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Generate dispatcher registration code for CK Tile kernels
This script generates C++ registration code that instantiates TileKernelInstance
templates for each generated kernel, solving the "cannot instantiate from parsed headers" problem.
"""
import json
import argparse
from pathlib import Path
from typing import List
from dataclasses import dataclass
@dataclass
class KernelConfig:
    """Kernel configuration for registration"""

    # Identification.
    name: str         # kernel identifier; used to form SelectedKernel_<name>
    header_file: str  # generated header to #include for this kernel
    # Block-tile shape (per workgroup).
    tile_m: int
    tile_n: int
    tile_k: int
    # Warps per block along M/N/K.
    warp_m: int
    warp_n: int
    warp_k: int
    # Per-warp tile shape.
    warp_tile_m: int
    warp_tile_n: int
    warp_tile_k: int
    block_size: int  # threads per workgroup
    # Trait selection keys (e.g. "compv4" / "cshuffle" / "intrawave").
    pipeline: str
    epilogue: str
    scheduler: str
    # Padding flags for the M/N/K dimensions.
    pad_m: bool
    pad_n: bool
    pad_k: bool
    persistent: bool     # persistent-kernel variant
    double_buffer: bool  # double LDS buffering
    transpose_c: bool    # transpose of the C output
    # Data types (defaulted; manifest entries may override).
    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"
    # Matrix layouts ("row"/"col" codes for A, B, C).
    layout_a: str = "row"
    layout_b: str = "col"
    layout_c: str = "row"
def generate_registration_header(kernels: List[KernelConfig], output_file: Path):
    """Write the auto-generated registration header.

    The header #includes every generated kernel header and defines
    ``register_all_kernels`` helpers that register each kernel (through its
    ``SelectedKernel_<name>`` alias) with the dispatcher registry.
    """
    pieces = ["""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#pragma once

#include "ck_tile/dispatcher/registry.hpp"
#include "ck_tile/dispatcher/backends/tile_backend.hpp"
#include "ck_tile/dispatcher/backends/kernel_registration.hpp"

// Include all generated kernel headers
"""]

    # One include per kernel header.
    pieces.extend(f'#include "{k.header_file}"\n' for k in kernels)

    pieces.append("""
namespace ck_tile {
namespace dispatcher {
namespace generated {

/// Register all generated kernels with the dispatcher
inline void register_all_kernels(Registry& registry)
{
""")

    # One registration call per kernel; the SelectedKernel_<name> alias is
    # expected to be provided by the included kernel headers.
    for k in kernels:
        alias = f"SelectedKernel_{k.name}"
        pieces.append(
            f"    // Register {k.name}\n"
            f'    register_tile_kernel<{alias}>(registry, "{k.name}");\n'
        )

    pieces.append("""}

/// Register all generated kernels with the global registry
inline void register_all_kernels()
{
    auto& registry = Registry::instance();
    register_all_kernels(registry);
}

} // namespace generated
} // namespace dispatcher
} // namespace ck_tile
""")

    output_file.write_text("".join(pieces))
    print(f"✓ Generated registration header: {output_file}")
def generate_registration_cpp(kernels: List[KernelConfig], output_file: Path):
    """Write the registration .cpp containing explicit instantiations.

    Instantiating each TileKernelInstance once here keeps other translation
    units from re-instantiating the templates.
    """
    prologue = """// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#include "dispatcher_registration.hpp"

namespace ck_tile {
namespace dispatcher {
namespace generated {

// Explicit instantiations to reduce compile time
// These ensure the templates are instantiated once
"""
    instantiations = "".join(
        f"template class backends::TileKernelInstance<SelectedKernel_{k.name}>;\n"
        for k in kernels
    )
    epilogue = """
} // namespace generated
} // namespace dispatcher
} // namespace ck_tile
"""
    output_file.write_text(prologue + instantiations + epilogue)
    print(f"✓ Generated registration implementation: {output_file}")
def generate_kernel_wrapper_header(kernel: KernelConfig, output_dir: Path):
    """Generate a wrapper header that defines SelectedKernel type"""
    # NOTE(review): the emitted `using SelectedKernel_<name> = /* ... */;`
    # line still contains a placeholder comment instead of a real type, so
    # the produced header is NOT valid C++ as-is.  main() never calls this
    # function -- confirm intent before wiring it into the generation flow.
    wrapper_file = output_dir / f"{kernel.name}_wrapper.hpp"
    content = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
//
// AUTO-GENERATED FILE - DO NOT EDIT
// Generated by generate_dispatcher_registration.py

#pragma once

#include "{kernel.header_file}"

namespace ck_tile {{
namespace dispatcher {{
namespace generated {{

// Type alias for dispatcher registration
// This allows the registration code to reference the kernel type
using SelectedKernel_{kernel.name} = /* Actual kernel type from generated header */;

}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
"""
    wrapper_file.write_text(content)
def load_kernel_manifest(manifest_file: Path) -> List[KernelConfig]:
    """Load kernel configurations from a JSON manifest file.

    Each entry must provide "name", "header_file" and the tile sizes; every
    other field falls back to a sensible default.

    Args:
        manifest_file: Path to a JSON file with a top-level "kernels" list.

    Returns:
        One KernelConfig per manifest entry (empty list for no "kernels" key).

    Raises:
        KeyError: if an entry lacks a required field.
        OSError / json.JSONDecodeError: if the file is unreadable or invalid.
    """
    with open(manifest_file, "r") as f:
        data = json.load(f)

    kernels = []
    for kernel_data in data.get("kernels", []):
        kernel = KernelConfig(
            name=kernel_data["name"],
            header_file=kernel_data["header_file"],
            tile_m=kernel_data["tile_m"],
            tile_n=kernel_data["tile_n"],
            tile_k=kernel_data["tile_k"],
            warp_m=kernel_data.get("warp_m", 2),
            warp_n=kernel_data.get("warp_n", 2),
            warp_k=kernel_data.get("warp_k", 1),
            warp_tile_m=kernel_data.get("warp_tile_m", 32),
            warp_tile_n=kernel_data.get("warp_tile_n", 32),
            warp_tile_k=kernel_data.get("warp_tile_k", 16),
            block_size=kernel_data.get("block_size", 256),
            pipeline=kernel_data.get("pipeline", "compv4"),
            epilogue=kernel_data.get("epilogue", "cshuffle"),
            scheduler=kernel_data.get("scheduler", "intrawave"),
            pad_m=kernel_data.get("pad_m", False),
            pad_n=kernel_data.get("pad_n", False),
            pad_k=kernel_data.get("pad_k", False),
            persistent=kernel_data.get("persistent", False),
            double_buffer=kernel_data.get("double_buffer", True),
            transpose_c=kernel_data.get("transpose_c", False),
            dtype_a=kernel_data.get("dtype_a", "fp16"),
            dtype_b=kernel_data.get("dtype_b", "fp16"),
            dtype_c=kernel_data.get("dtype_c", "fp16"),
            dtype_acc=kernel_data.get("dtype_acc", "fp32"),
            # Fix: layouts were previously ignored even when present in the
            # manifest; honour them now, keeping the dataclass defaults.
            layout_a=kernel_data.get("layout_a", "row"),
            layout_b=kernel_data.get("layout_b", "col"),
            layout_c=kernel_data.get("layout_c", "row"),
        )
        kernels.append(kernel)
    return kernels
def scan_generated_headers(generated_dir: Path) -> List[KernelConfig]:
    """Scan generated headers and extract kernel configurations.

    Walks every ``*.hpp`` under ``generated_dir``, keeps files that declare a
    ``KERNEL_NAME`` constant and parses the tile/warp/flag constants out of
    the source text.  Headers that fail to parse are skipped with a warning
    instead of aborting the scan.
    """
    import re

    # `static constexpr <int-ish type> <Symbol> = <digits>` -- the single
    # shape shared by every integer constant extracted below (supports
    # int, std::size_t and ck_tile::index_t).
    int_pattern = (
        r"(?:static\s+)?constexpr\s+(?:int|std::size_t|ck_tile::index_t)"
        r"\s+{}\s*=\s*(\d+)"
    )

    def read_int(text: str, symbol: str, default: int) -> int:
        """Extract `constexpr ... <symbol> = N`, or fall back to `default`."""
        match = re.search(int_pattern.format(symbol), text)
        return int(match.group(1)) if match else default

    def read_flag(text: str, symbol: str) -> bool:
        """True iff the header sets `<symbol> = true`."""
        return re.search(symbol + r"\s*=\s*true", text) is not None

    kernels = []
    for header_file in generated_dir.glob("**/*.hpp"):
        try:
            content = header_file.read_text()
            # Files without a KERNEL_NAME constant are not kernel headers.
            name_match = re.search(
                r'constexpr const char\* KERNEL_NAME\s*=\s*"([^"]+)"', content
            )
            if not name_match:
                continue
            kernel = KernelConfig(
                name=name_match.group(1),
                header_file=str(header_file.relative_to(generated_dir.parent)),
                # Block tile.
                tile_m=read_int(content, "TileM", 256),
                tile_n=read_int(content, "TileN", 256),
                tile_k=read_int(content, "TileK", 32),
                # Warps per block.
                warp_m=read_int(content, "WarpPerBlock_M", 2),
                warp_n=read_int(content, "WarpPerBlock_N", 2),
                warp_k=read_int(content, "WarpPerBlock_K", 1),
                # Per-warp tile.
                warp_tile_m=read_int(content, "WarpTileM", 32),
                warp_tile_n=read_int(content, "WarpTileN", 32),
                warp_tile_k=read_int(content, "WarpTileK", 16),
                block_size=read_int(content, "BlockSize", 256),
                # Traits are not recorded in the headers; these are assumed
                # defaults -- TODO: emit/parse them to avoid drift.
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                # Boolean flags.
                pad_m=read_flag(content, "kPadM"),
                pad_n=read_flag(content, "kPadN"),
                pad_k=read_flag(content, "kPadK"),
                persistent=read_flag(content, "UsePersistentKernel"),
                double_buffer=read_flag(content, "DoubleSmemBuffer"),
                transpose_c=read_flag(content, "TransposeC"),
            )
            kernels.append(kernel)
        except Exception as e:
            # Best-effort scan: report the bad header and keep going.
            print(f"Warning: Failed to parse {header_file}: {e}")
            continue
    return kernels
def main():
    """CLI entry point for dispatcher registration code generation."""
    parser = argparse.ArgumentParser(
        description="Generate dispatcher registration code"
    )
    parser.add_argument(
        "--generated-dir",
        type=str,
        required=True,
        help="Directory containing generated kernel headers",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        required=True,
        help="Output directory for registration code",
    )
    parser.add_argument(
        "--manifest", type=str, help="Optional manifest file with kernel configurations"
    )
    parser.add_argument(
        "--scan",
        action="store_true",
        help="Scan generated headers instead of using manifest",
    )
    opts = parser.parse_args()

    generated_dir = Path(opts.generated_dir)
    target_dir = Path(opts.output_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Pick the kernel source: an explicit manifest wins over scanning.
    if opts.manifest:
        print(f"Loading kernels from manifest: {opts.manifest}")
        kernels = load_kernel_manifest(Path(opts.manifest))
    elif opts.scan:
        print(f"Scanning generated headers in: {generated_dir}")
        kernels = scan_generated_headers(generated_dir)
    else:
        print("Error: Must specify either --manifest or --scan")
        return 1

    print(f"Found {len(kernels)} kernels")

    # Emit the registration header and implementation.
    hpp_path = target_dir / "dispatcher_registration.hpp"
    cpp_path = target_dir / "dispatcher_registration.cpp"
    generate_registration_header(kernels, hpp_path)
    generate_registration_cpp(kernels, cpp_path)

    # Emit a compact manifest so Python tooling can enumerate the kernels.
    manifest_path = target_dir / "kernels_manifest.json"
    manifest_payload = {
        "kernels": [
            {
                "name": k.name,
                "header_file": k.header_file,
                "tile_m": k.tile_m,
                "tile_n": k.tile_n,
                "tile_k": k.tile_k,
                "block_size": k.block_size,
                "persistent": k.persistent,
            }
            for k in kernels
        ]
    }
    manifest_path.write_text(json.dumps(manifest_payload, indent=2))
    print(f"✓ Generated manifest: {manifest_path}")

    print("\n✓ Registration code generation complete!")
    print(f" Total kernels: {len(kernels)}")
    print(" Output files:")
    print(f" - {hpp_path}")
    print(f" - {cpp_path}")
    print(f" - {manifest_path}")
    return 0


if __name__ == "__main__":
    exit(main())

View File

@@ -0,0 +1,430 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Generate one .cpp wrapper file per kernel header for maximum parallel compilation.
Each kernel becomes its own translation unit, enabling:
- Maximum parallelism with make -j$(nproc)
- Per-kernel build progress (e.g., [5/128] Building kernel: gemm_fp16_128x128)
- Incremental rebuilds (only changed kernels recompile)
- Fine-grained build time analysis
Usage:
python3 generate_kernel_wrappers.py --kernel-dir build/generated_kernels --output-dir build/kernel_wrappers
Output structure:
build/kernel_wrappers/
├── gemm_fp16_rcr_128x128x32.cpp
├── gemm_fp16_rcr_256x256x64.cpp
├── conv_fwd_fp16_2d_128x128.cpp
└── ...
Each .cpp simply includes its corresponding .hpp and forces symbol emission.
"""
import argparse
import sys
from pathlib import Path
from typing import List, Tuple
import concurrent.futures
WRAPPER_TEMPLATE_GEMM = """// SPDX-License-Identifier: MIT
// Auto-generated wrapper for: {kernel_name}
// This file enables per-kernel parallel compilation
#include "{kernel_hpp}"
// Force symbol emission for kernel registration
namespace ck_tile {{
namespace dispatcher {{
namespace generated {{
// Marker to prevent dead code elimination
volatile bool _{kernel_id}_registered = true;
}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
"""
WRAPPER_TEMPLATE_CONV = """// SPDX-License-Identifier: MIT
// Auto-generated wrapper for: {kernel_name}
// This file enables per-kernel parallel compilation
#include "{kernel_hpp}"
namespace ck_tile {{
namespace dispatcher {{
namespace generated {{
volatile bool _{kernel_id}_registered = true;
}} // namespace generated
}} // namespace dispatcher
}} // namespace ck_tile
"""
def generate_wrapper(
    kernel_hpp: Path, output_dir: Path, index: int, total: int
) -> Tuple[Path, bool]:
    """Generate a .cpp wrapper for one kernel header.

    Returns the wrapper path and True when the file was (re)written, or
    False when an identical wrapper already existed (keeps incremental
    builds cheap).  `index`/`total` are currently unused and kept for
    progress-reporting callers.
    """
    stem = kernel_hpp.stem
    safe_id = stem.replace("-", "_").replace(".", "_")

    # GEMM kernels get the template with the dead-code-elimination note;
    # everything else uses the conv variant.
    template = (
        WRAPPER_TEMPLATE_GEMM if stem.startswith("gemm") else WRAPPER_TEMPLATE_CONV
    )
    body = template.format(
        kernel_name=stem,
        kernel_hpp=kernel_hpp.name,
        kernel_id=safe_id,
    )

    target = output_dir / f"{stem}.cpp"
    # Skip the write when content is unchanged (incremental builds).
    if target.exists() and target.read_text() == body:
        return target, False
    target.write_text(body)
    return target, True
def generate_cmake_list(
    wrappers: List[Path], output_dir: Path, kernel_dir: Path
) -> Path:
    """Write a CMakeLists.txt compiling each wrapper as its own OBJECT target.

    Every wrapper becomes an OBJECT library so `make -j` can build kernels
    fully in parallel; all objects are then linked into one shared library.
    Returns the path of the written CMakeLists.txt.
    """
    total = len(wrappers)

    chunks = [f'''# SPDX-License-Identifier: MIT
# Auto-generated CMakeLists.txt for per-kernel parallel compilation
# Generated {total} kernel translation units

cmake_minimum_required(VERSION 3.16)

# =============================================================================
# Per-Kernel Object Targets ({total} kernels)
# =============================================================================
# Each kernel is compiled as a separate OBJECT library for maximum parallelism.
# Build with: make -j$(nproc) all_kernels
#
# Progress output:
# [ 1/{total}] Building kernel: gemm_fp16_rcr_128x128x32
# [ 2/{total}] Building kernel: gemm_fp16_rcr_256x256x64
# ...

set(KERNEL_INCLUDE_DIR "{kernel_dir}")
set(ALL_KERNEL_OBJECTS "")
''']

    # One OBJECT target per wrapper.
    for position, wrapper in enumerate(wrappers, 1):
        stem = wrapper.stem
        target = f"kobj_{stem}"
        chunks.append(f"""
# [{position}/{total}] {stem}
add_library({target} OBJECT {wrapper.name})
target_include_directories({target} PRIVATE ${{KERNEL_INCLUDE_DIR}} ${{CK_INCLUDE_DIR}})
target_compile_options({target} PRIVATE
    -mllvm -enable-noalias-to-md-conversion=0
    -Wno-undefined-func-template
    -Wno-float-equal
    --offload-compress
)
set_target_properties({target} PROPERTIES POSITION_INDEPENDENT_CODE ON)
if(hip_FOUND)
    target_link_libraries({target} PRIVATE hip::device hip::host)
endif()
list(APPEND ALL_KERNEL_OBJECTS $<TARGET_OBJECTS:{target}>)
""")

    chunks.append(f"""
# =============================================================================
# Combined Kernel Library
# =============================================================================
# Links all {total} kernel objects into a single shared library

add_library(all_kernels SHARED ${{ALL_KERNEL_OBJECTS}})
if(hip_FOUND)
    target_link_libraries(all_kernels PRIVATE hip::device hip::host)
endif()
set_target_properties(all_kernels PROPERTIES
    POSITION_INDEPENDENT_CODE ON
    OUTPUT_NAME "dispatcher_kernels"
)

message(STATUS "Configured {total} kernel objects for parallel compilation")
message(STATUS "Build with: make -j$(nproc) all_kernels")
""")

    cmake_file = output_dir / "CMakeLists.txt"
    cmake_file.write_text("".join(chunks))
    return cmake_file
def generate_ninja_build(
    wrappers: List[Path], output_dir: Path, kernel_dir: Path
) -> Path:
    """Generate build.ninja for even faster parallel compilation.

    Args:
        wrappers: Wrapper .cpp files, one per kernel.
        output_dir: Directory the build.ninja is written into.
        kernel_dir: Directory added to the compiler include path.

    Returns:
        Path of the written build.ninja.
    """
    num_kernels = len(wrappers)
    # NOTE(review): the hipcc command, flags and /opt/rocm install prefix are
    # hard-coded below -- confirm they match the target build environment.
    ninja_content = f"""# SPDX-License-Identifier: MIT
# Auto-generated build.ninja for per-kernel parallel compilation
# {num_kernels} kernel translation units

# Variables
cxx = hipcc
cxxflags = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 -Wno-undefined-func-template -Wno-float-equal --offload-compress
includes = -I{kernel_dir} -I/opt/rocm/include

# Rules
rule compile
  command = $cxx $cxxflags $includes -c $in -o $out
  description = [{num_kernels}] Building kernel: $kernel_name

rule link
  command = $cxx -shared $in -o $out -L/opt/rocm/lib -lamdhip64
  description = Linking: $out

# Kernel objects
"""
    obj_files = []
    # One build statement per kernel object; `idx` is unused here and kept
    # only for parity with the sibling generators.
    for idx, wrapper in enumerate(wrappers, 1):
        kernel_name = wrapper.stem
        obj_file = f"{kernel_name}.o"
        obj_files.append(obj_file)
        ninja_content += f"""
build {obj_file}: compile {wrapper.name}
  kernel_name = {kernel_name}
"""
    ninja_content += f"""
# Shared library
build libdispatcher_kernels.so: link {" ".join(obj_files)}

# Default target
default libdispatcher_kernels.so
"""
    ninja_file = output_dir / "build.ninja"
    ninja_file.write_text(ninja_content)
    return ninja_file
def generate_makefile(wrappers: List[Path], output_dir: Path, kernel_dir: Path) -> Path:
    """Generate Makefile for per-kernel parallel compilation.

    Args:
        wrappers: Wrapper .cpp files, one per kernel.
        output_dir: Directory the Makefile is written into.
        kernel_dir: Directory added to the compiler include path.

    Returns:
        Path of the written Makefile.
    """
    num_kernels = len(wrappers)
    kernel_names = [w.stem for w in wrappers]
    obj_files = [f"{name}.o" for name in kernel_names]
    # NOTE(review): compiler, flags and /opt/rocm paths are hard-coded;
    # confirm they match the target build environment.
    makefile_content = f"""# SPDX-License-Identifier: MIT
# Auto-generated Makefile for per-kernel parallel compilation
# {num_kernels} kernel translation units
#
# Usage:
# make -j$(nproc) # Build all kernels in parallel
# make -j$(nproc) VERBOSE=1 # With per-kernel progress
# make clean # Remove all objects

CXX = hipcc
CXXFLAGS = -fPIC -std=c++17 -O3 -mllvm -enable-noalias-to-md-conversion=0 \\
-Wno-undefined-func-template -Wno-float-equal --offload-compress
INCLUDES = -I{kernel_dir} -I/opt/rocm/include
LDFLAGS = -shared -L/opt/rocm/lib -lamdhip64

TARGET = libdispatcher_kernels.so
OBJECTS = {" ".join(obj_files)}

# Progress counter (only works with make -j1, use ninja for parallel progress)
TOTAL_KERNELS = {num_kernels}
CURRENT = 0

.PHONY: all clean

all: $(TARGET)
\t@echo "Built $(TARGET) with {num_kernels} kernels"

$(TARGET): $(OBJECTS)
\t@echo "[LINK] Linking {num_kernels} kernel objects -> $@"
\t$(CXX) $(LDFLAGS) $^ -o $@
"""
    # One explicit rule per kernel object so make prints a progress line
    # for each kernel it compiles.
    for idx, (wrapper, obj) in enumerate(zip(wrappers, obj_files), 1):
        kernel_name = wrapper.stem
        makefile_content += f"""
{obj}: {wrapper.name}
\t@echo "[{idx}/{num_kernels}] Building kernel: {kernel_name}"
\t$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
"""
    makefile_content += f"""
clean:
\trm -f $(OBJECTS) $(TARGET)
\t@echo "Cleaned {num_kernels} kernel objects"
"""
    makefile = output_dir / "Makefile"
    makefile.write_text(makefile_content)
    return makefile
def main():
    """CLI entry point: emit one .cpp wrapper per kernel header.

    Returns a process exit code: 0 on success, 1 when the kernel directory
    is missing or no headers match the pattern.
    """
    parser = argparse.ArgumentParser(
        description="Generate per-kernel wrapper .cpp files for parallel compilation"
    )
    parser.add_argument(
        "--kernel-dir",
        type=Path,
        required=True,
        help="Directory containing generated kernel .hpp files",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Output directory for wrapper .cpp files",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default="*.hpp",
        help="Glob pattern for kernel headers (default: *.hpp)",
    )
    parser.add_argument(
        "--generate-cmake",
        action="store_true",
        help="Generate CMakeLists.txt for the wrappers",
    )
    parser.add_argument(
        "--generate-ninja",
        action="store_true",
        help="Generate build.ninja for ninja builds",
    )
    parser.add_argument(
        "--generate-makefile",
        action="store_true",
        help="Generate Makefile for make builds",
    )
    # NOTE(review): store_true with default=True means --parallel can never
    # be turned off from the command line -- confirm whether a
    # --no-parallel/--serial switch was intended.
    parser.add_argument(
        "--parallel",
        action="store_true",
        default=True,
        help="Generate wrappers in parallel (default: True)",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Verbose output",
    )
    args = parser.parse_args()

    # Find kernel headers; sorted for deterministic wrapper/build ordering.
    kernel_dir = args.kernel_dir.resolve()
    if not kernel_dir.exists():
        print(f"Error: Kernel directory not found: {kernel_dir}", file=sys.stderr)
        return 1
    kernel_headers = sorted(kernel_dir.glob(args.pattern))
    if not kernel_headers:
        print(
            f"Error: No kernel headers found matching {args.pattern} in {kernel_dir}",
            file=sys.stderr,
        )
        return 1
    num_kernels = len(kernel_headers)
    print(f"Found {num_kernels} kernel headers in {kernel_dir}")

    # Create output directory
    output_dir = args.output_dir.resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate wrappers; wrapper generation is file I/O, so a thread pool
    # is sufficient for the parallel path.
    print(f"Generating {num_kernels} wrapper .cpp files...")
    wrappers = []
    written = 0
    if args.parallel and num_kernels > 1:
        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(
                    generate_wrapper, hpp, output_dir, idx, num_kernels
                ): hpp
                for idx, hpp in enumerate(kernel_headers, 1)
            }
            for future in concurrent.futures.as_completed(futures):
                wrapper_path, was_written = future.result()
                wrappers.append(wrapper_path)
                if was_written:
                    written += 1
                if args.verbose:
                    print(f" Generated: {wrapper_path.name}")
    else:
        for idx, hpp in enumerate(kernel_headers, 1):
            wrapper_path, was_written = generate_wrapper(
                hpp, output_dir, idx, num_kernels
            )
            wrappers.append(wrapper_path)
            if was_written:
                written += 1
            if args.verbose:
                print(f" [{idx}/{num_kernels}] Generated: {wrapper_path.name}")
    # as_completed() yields in completion order, so restore name order.
    wrappers.sort(key=lambda p: p.name)
    print(
        f" Total: {num_kernels} wrappers ({written} written, {num_kernels - written} unchanged)"
    )

    # Generate whichever build files were requested.
    if args.generate_cmake:
        cmake_file = generate_cmake_list(wrappers, output_dir, kernel_dir)
        print(f" Generated: {cmake_file}")
    if args.generate_ninja:
        ninja_file = generate_ninja_build(wrappers, output_dir, kernel_dir)
        print(f" Generated: {ninja_file}")
    if args.generate_makefile:
        makefile = generate_makefile(wrappers, output_dir, kernel_dir)
        print(f" Generated: {makefile}")

    print(f"\nOutput directory: {output_dir}")
    print(f"Kernels ready for parallel compilation: {num_kernels}")
    print("\nTo build:")
    print(f" cd {output_dir}")
    if args.generate_makefile:
        print(" make -j$(nproc) # Parallel build with progress")
    if args.generate_ninja:
        print(" ninja # Fast parallel build")
    if args.generate_cmake:
        print(" cmake -B build && cmake --build build -j$(nproc)")
    return 0


if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,798 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Kernel Configuration Loader
Load kernel configurations from JSON files for generating specific kernel sets.
Compatible with tile_engine JSON format.
Usage:
from kernel_config_loader import load_kernel_configs, KernelConfigSet
# Load configs from JSON
config_set = load_kernel_configs("my_kernels.json")
# Get all configurations (cartesian product of all parameter values)
for config in config_set.generate_configs():
print(config)
# Use with codegen
from unified_gemm_codegen import UnifiedGemmCodegen
codegen = UnifiedGemmCodegen(...)
codegen.generate_from_configs(config_set.generate_configs())
"""
import json
import itertools
from dataclasses import dataclass, field
from pathlib import Path
from typing import List, Dict, Any, Optional, Iterator
@dataclass
class TileConfig:
"""Tile configuration for a kernel"""
tile_m: int = 128
tile_n: int = 128
tile_k: int = 32
warp_m: int = 2
warp_n: int = 2
warp_k: int = 1
warp_tile_m: int = 32
warp_tile_n: int = 32
warp_tile_k: int = 16
@dataclass
class TraitConfig:
"""Trait configuration for a kernel (order matches GEMM/Conv TraitConfig)"""
pipeline: str = "compv4"
epilogue: str = "cshuffle"
scheduler: str = "intrawave"
pad_m: bool = False
pad_n: bool = False
pad_k: bool = False
@dataclass
class KernelConfig:
"""Complete kernel configuration"""
tile: TileConfig = field(default_factory=TileConfig)
trait: TraitConfig = field(default_factory=TraitConfig)
dtype_a: str = "fp16"
dtype_b: str = "fp16"
dtype_c: str = "fp16"
dtype_acc: str = "fp32"
layout: str = "rcr"
gpu_target: str = "gfx942"
variant: str = "standard"
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for codegen"""
return {
"tile_m": self.tile.tile_m,
"tile_n": self.tile.tile_n,
"tile_k": self.tile.tile_k,
"warp_m": self.tile.warp_m,
"warp_n": self.tile.warp_n,
"warp_k": self.tile.warp_k,
"warp_tile_m": self.tile.warp_tile_m,
"warp_tile_n": self.tile.warp_tile_n,
"warp_tile_k": self.tile.warp_tile_k,
"pipeline": self.trait.pipeline,
"scheduler": self.trait.scheduler,
"epilogue": self.trait.epilogue,
"pad_m": self.trait.pad_m,
"pad_n": self.trait.pad_n,
"pad_k": self.trait.pad_k,
"dtype_a": self.dtype_a,
"dtype_b": self.dtype_b,
"dtype_c": self.dtype_c,
"dtype_acc": self.dtype_acc,
"layout": self.layout,
"gpu_target": self.gpu_target,
"variant": self.variant,
}
def kernel_name(self) -> str:
"""Generate kernel name from config"""
name = f"gemm_{self.dtype_a}_{self.layout}_{self.trait.pipeline}"
name += f"_{self.trait.epilogue}_{self.trait.scheduler}"
name += f"_{str(self.trait.pad_m).capitalize()}"
name += f"_{str(self.trait.pad_n).capitalize()}"
name += f"_{str(self.trait.pad_k).capitalize()}"
name += "_False" # preshuffle
name += f"_{self.tile.tile_m}x{self.tile.tile_n}x{self.tile.tile_k}"
name += f"_{self.tile.warp_m}x{self.tile.warp_n}x{self.tile.warp_k}"
name += (
f"_{self.tile.warp_tile_m}x{self.tile.warp_tile_n}x{self.tile.warp_tile_k}"
)
return name
@dataclass
class KernelConfigSet:
    """A set of kernel configurations loaded from JSON"""
    name: str = "default"
    configs: List[KernelConfig] = field(default_factory=list)
    # Parameter ranges for generation
    tile_m_values: List[int] = field(default_factory=lambda: [128])
    tile_n_values: List[int] = field(default_factory=lambda: [128])
    tile_k_values: List[int] = field(default_factory=lambda: [32])
    warp_m_values: List[int] = field(default_factory=lambda: [2])
    warp_n_values: List[int] = field(default_factory=lambda: [2])
    warp_k_values: List[int] = field(default_factory=lambda: [1])
    warp_tile_m_values: List[int] = field(default_factory=lambda: [32])
    warp_tile_n_values: List[int] = field(default_factory=lambda: [32])
    warp_tile_k_values: List[int] = field(default_factory=lambda: [16])
    pipeline_values: List[str] = field(default_factory=lambda: ["compv4"])
    scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"])
    epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"])
    pad_m_values: List[bool] = field(default_factory=lambda: [False])
    pad_n_values: List[bool] = field(default_factory=lambda: [False])
    pad_k_values: List[bool] = field(default_factory=lambda: [False])
    dtype_a: str = "fp16"
    dtype_b: str = "fp16"
    dtype_c: str = "fp16"
    dtype_acc: str = "fp32"
    layout: str = "rcr"
    gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"])
    variant: str = "standard"

    def generate_configs(self) -> Iterator[KernelConfig]:
        """Yield every kernel configuration in the cartesian product of all ranges."""
        tile_list = list(
            itertools.product(
                self.tile_m_values,
                self.tile_n_values,
                self.tile_k_values,
                self.warp_m_values,
                self.warp_n_values,
                self.warp_k_values,
                self.warp_tile_m_values,
                self.warp_tile_n_values,
                self.warp_tile_k_values,
            )
        )
        trait_list = list(
            itertools.product(
                self.pipeline_values,
                self.scheduler_values,
                self.epilogue_values,
                self.pad_m_values,
                self.pad_n_values,
                self.pad_k_values,
            )
        )
        # product() varies the last argument fastest, matching the original
        # nested loops: gpu_target (slowest) -> tile -> trait (fastest).
        for gpu_target, tile, trait in itertools.product(
            self.gpu_targets, tile_list, trait_list
        ):
            tm, tn, tk, wm, wn, wk, wtm, wtn, wtk = tile
            pipeline, scheduler, epilogue, pm, pn, pk = trait
            yield KernelConfig(
                tile=TileConfig(
                    tile_m=tm,
                    tile_n=tn,
                    tile_k=tk,
                    warp_m=wm,
                    warp_n=wn,
                    warp_k=wk,
                    warp_tile_m=wtm,
                    warp_tile_n=wtn,
                    warp_tile_k=wtk,
                ),
                trait=TraitConfig(
                    pipeline=pipeline,
                    scheduler=scheduler,
                    epilogue=epilogue,
                    pad_m=pm,
                    pad_n=pn,
                    pad_k=pk,
                ),
                dtype_a=self.dtype_a,
                dtype_b=self.dtype_b,
                dtype_c=self.dtype_c,
                dtype_acc=self.dtype_acc,
                layout=self.layout,
                gpu_target=gpu_target,
                variant=self.variant,
            )

    def config_count(self) -> int:
        """Total number of configurations generate_configs() will yield."""
        total = len(self.gpu_targets)
        for values in (
            self.tile_m_values,
            self.tile_n_values,
            self.tile_k_values,
            self.warp_m_values,
            self.warp_n_values,
            self.warp_k_values,
            self.warp_tile_m_values,
            self.warp_tile_n_values,
            self.warp_tile_k_values,
            self.pipeline_values,
            self.scheduler_values,
            self.epilogue_values,
            self.pad_m_values,
            self.pad_n_values,
            self.pad_k_values,
        ):
            total *= len(values)
        return total
def _get_values(config: Dict, key: str, default: List) -> List:
"""Extract values from config dict, handling range specifications"""
if key not in config:
return default
item = config[key]
# Explicit values list
if "values" in item:
return item["values"]
# Range specification (min, max, step)
if "min" in item and "max" in item:
min_val = item["min"]
max_val = item["max"]
step = item.get("step", 1)
return list(range(min_val, max_val + 1, step))
return default
def load_kernel_configs(json_path: str | Path) -> KernelConfigSet:
    """
    Load kernel configurations from a JSON file.
    Supports both tile_engine format and dispatcher format.
    Args:
        json_path: Path to JSON configuration file
    Returns:
        KernelConfigSet with all parameter values loaded
    """
    json_path = Path(json_path)
    with open(json_path) as f:
        data = json.load(f)

    config_set = KernelConfigSet()
    config_set.name = data.get("kernel_set_name", json_path.stem)

    # Optional "datatype" section with per-operand types.
    if "datatype" in data:
        dt = data["datatype"]
        config_set.dtype_a = dt.get("a", "fp16")
        config_set.dtype_b = dt.get("b", "fp16")
        config_set.dtype_c = dt.get("c", "fp16")
        config_set.dtype_acc = dt.get("acc", "fp32")

    config_set.layout = data.get("layout", "rcr")

    # GPU targets: plural key takes precedence over the singular form.
    if "gpu_targets" in data:
        config_set.gpu_targets = data["gpu_targets"]
    elif "gpu_target" in data:
        config_set.gpu_targets = [data["gpu_target"]]

    config_set.variant = data.get("variant", "standard")

    # Tile-shape parameter ranges: (parameter name, default values).
    tile_cfg = data.get("tile_config", {})
    for attr, default in (
        ("tile_m", [128]),
        ("tile_n", [128]),
        ("tile_k", [32]),
        ("warp_m", [2]),
        ("warp_n", [2]),
        ("warp_k", [1]),
        ("warp_tile_m", [32]),
        ("warp_tile_n", [32]),
        ("warp_tile_k", [16]),
    ):
        setattr(config_set, f"{attr}_values", _get_values(tile_cfg, attr, default))

    # Trait parameter ranges.
    trait_cfg = data.get("trait_config", {})
    for attr, default in (
        ("pipeline", ["compv4"]),
        ("scheduler", ["intrawave"]),
        ("epilogue", ["cshuffle"]),
        ("pad_m", [False]),
        ("pad_n", [False]),
        ("pad_k", [False]),
    ):
        setattr(config_set, f"{attr}_values", _get_values(trait_cfg, attr, default))

    return config_set
# =============================================================================
# Convolution Configuration Classes
# =============================================================================
@dataclass
class ConvTileConfig:
    """Tile configuration for a convolution kernel (GEMM-view dimensions)"""
    tile_m: int = 128  # M dimension (N * spatial_out for fwd)
    tile_n: int = 128  # N dimension (K output channels for fwd)
    tile_k: int = 32  # K dimension (C * filter for fwd)
    # Warps per block along each GEMM dimension.
    warp_m: int = 2
    warp_n: int = 2
    warp_k: int = 1
    # Per-warp tile extents along each GEMM dimension.
    warp_tile_m: int = 32
    warp_tile_n: int = 32
    warp_tile_k: int = 16
@dataclass
class ConvTraitConfig:
    """Trait configuration for a convolution kernel"""
    # Pipeline / scheduler / epilogue names forwarded verbatim to codegen.
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    epilogue: str = "cshuffle"
    # Conv defaults pad all dimensions (unlike the GEMM TraitConfig defaults).
    pad_m: bool = True
    pad_n: bool = True
    pad_k: bool = True
    # Whether the kernel double-buffers shared memory.
    double_smem_buffer: bool = False
    # Number of convolution groups merged per workgroup.
    num_groups_to_merge: int = 1
@dataclass
class ConvKernelConfig:
    """Complete convolution kernel configuration"""
    tile: ConvTileConfig = field(default_factory=ConvTileConfig)
    trait: ConvTraitConfig = field(default_factory=ConvTraitConfig)
    dtype_input: str = "fp16"
    dtype_weight: str = "fp16"
    dtype_output: str = "fp16"
    dtype_acc: str = "fp32"
    variant: str = "forward"  # forward, bwd_data, bwd_weight
    ndim: int = 2  # 1, 2, or 3
    layout: str = "nhwgc"
    gpu_target: str = "gfx942"
    # Vector sizes
    vector_size_a: int = 4
    vector_size_b: int = 8
    vector_size_c: int = 8
    # Occupancy
    block_per_cu: int = 1
    num_wave_groups: int = 1

    def to_dict(self) -> Dict[str, Any]:
        """Flatten this configuration into the key/value mapping codegen consumes."""
        t, r = self.tile, self.trait
        tile_part = {
            "tile_m": t.tile_m,
            "tile_n": t.tile_n,
            "tile_k": t.tile_k,
            "warp_m": t.warp_m,
            "warp_n": t.warp_n,
            "warp_k": t.warp_k,
            "warp_tile_m": t.warp_tile_m,
            "warp_tile_n": t.warp_tile_n,
            "warp_tile_k": t.warp_tile_k,
        }
        trait_part = {
            "pipeline": r.pipeline,
            "scheduler": r.scheduler,
            "epilogue": r.epilogue,
            "pad_m": r.pad_m,
            "pad_n": r.pad_n,
            "pad_k": r.pad_k,
            "double_smem_buffer": r.double_smem_buffer,
            "num_groups_to_merge": r.num_groups_to_merge,
        }
        misc_part = {
            "dtype_input": self.dtype_input,
            "dtype_weight": self.dtype_weight,
            "dtype_output": self.dtype_output,
            "dtype_acc": self.dtype_acc,
            "variant": self.variant,
            "ndim": self.ndim,
            "layout": self.layout,
            "gpu_target": self.gpu_target,
            "vector_size_a": self.vector_size_a,
            "vector_size_b": self.vector_size_b,
            "vector_size_c": self.vector_size_c,
            "block_per_cu": self.block_per_cu,
            "num_wave_groups": self.num_wave_groups,
        }
        # Dict merge preserves insertion order, so the flat key order is unchanged.
        return {**tile_part, **trait_part, **misc_part}

    def kernel_name(self) -> str:
        """Build the canonical underscore-joined kernel name for this config."""
        variant_map = {"forward": "fwd", "bwd_data": "bwdd", "bwd_weight": "bwdw"}
        var_str = variant_map.get(self.variant, self.variant)
        t, r = self.tile, self.trait
        parts = [
            f"conv_{var_str}_{self.dtype_input}_{self.ndim}d",
            r.pipeline,
            r.epilogue,
            r.scheduler,
            f"{t.tile_m}x{t.tile_n}x{t.tile_k}",
            f"{t.warp_m}x{t.warp_n}x{t.warp_k}",
            f"{t.warp_tile_m}x{t.warp_tile_n}x{t.warp_tile_k}",
        ]
        return "_".join(parts)
@dataclass
class ConvKernelConfigSet:
    """A set of convolution kernel configurations loaded from JSON"""
    name: str = "default"
    configs: List[ConvKernelConfig] = field(default_factory=list)
    # Tile parameter ranges
    tile_m_values: List[int] = field(default_factory=lambda: [128])
    tile_n_values: List[int] = field(default_factory=lambda: [128])
    tile_k_values: List[int] = field(default_factory=lambda: [32])
    warp_m_values: List[int] = field(default_factory=lambda: [2])
    warp_n_values: List[int] = field(default_factory=lambda: [2])
    warp_k_values: List[int] = field(default_factory=lambda: [1])
    warp_tile_m_values: List[int] = field(default_factory=lambda: [32])
    warp_tile_n_values: List[int] = field(default_factory=lambda: [32])
    warp_tile_k_values: List[int] = field(default_factory=lambda: [16])
    # Trait parameter ranges
    pipeline_values: List[str] = field(default_factory=lambda: ["compv3"])
    scheduler_values: List[str] = field(default_factory=lambda: ["intrawave"])
    epilogue_values: List[str] = field(default_factory=lambda: ["cshuffle"])
    pad_m_values: List[bool] = field(default_factory=lambda: [True])
    pad_n_values: List[bool] = field(default_factory=lambda: [True])
    pad_k_values: List[bool] = field(default_factory=lambda: [True])
    double_smem_buffer_values: List[bool] = field(default_factory=lambda: [False])
    num_groups_to_merge_values: List[int] = field(default_factory=lambda: [1])
    # Vector sizes
    vector_size_a_values: List[int] = field(default_factory=lambda: [4])
    vector_size_b_values: List[int] = field(default_factory=lambda: [8])
    vector_size_c_values: List[int] = field(default_factory=lambda: [8])
    # Occupancy
    block_per_cu_values: List[int] = field(default_factory=lambda: [1])
    num_wave_groups_values: List[int] = field(default_factory=lambda: [1])
    # Data types
    dtype_input: str = "fp16"
    dtype_weight: str = "fp16"
    dtype_output: str = "fp16"
    dtype_acc: str = "fp32"
    # Conv specific
    variant: str = "forward"
    ndim: int = 2
    layout: str = "nhwgc"
    gpu_targets: List[str] = field(default_factory=lambda: ["gfx942"])

    def generate_configs(self) -> Iterator[ConvKernelConfig]:
        """Yield every kernel configuration in the cartesian product of all ranges."""
        tile_list = list(
            itertools.product(
                self.tile_m_values,
                self.tile_n_values,
                self.tile_k_values,
                self.warp_m_values,
                self.warp_n_values,
                self.warp_k_values,
                self.warp_tile_m_values,
                self.warp_tile_n_values,
                self.warp_tile_k_values,
            )
        )
        trait_list = list(
            itertools.product(
                self.pipeline_values,
                self.scheduler_values,
                self.epilogue_values,
                self.pad_m_values,
                self.pad_n_values,
                self.pad_k_values,
                self.double_smem_buffer_values,
                self.num_groups_to_merge_values,
            )
        )
        extra_list = list(
            itertools.product(
                self.vector_size_a_values,
                self.vector_size_b_values,
                self.vector_size_c_values,
                self.block_per_cu_values,
                self.num_wave_groups_values,
            )
        )
        # product() varies the last argument fastest, matching the original
        # nested loops: gpu_target (slowest) -> tile -> trait -> extra (fastest).
        for gpu_target, tile, trait, extra in itertools.product(
            self.gpu_targets, tile_list, trait_list, extra_list
        ):
            tm, tn, tk, wm, wn, wk, wtm, wtn, wtk = tile
            pipeline, scheduler, epilogue, pm, pn, pk, dsb, ngm = trait
            vec_a, vec_b, vec_c, bpc, nwg = extra
            yield ConvKernelConfig(
                tile=ConvTileConfig(
                    tile_m=tm,
                    tile_n=tn,
                    tile_k=tk,
                    warp_m=wm,
                    warp_n=wn,
                    warp_k=wk,
                    warp_tile_m=wtm,
                    warp_tile_n=wtn,
                    warp_tile_k=wtk,
                ),
                trait=ConvTraitConfig(
                    pipeline=pipeline,
                    scheduler=scheduler,
                    epilogue=epilogue,
                    pad_m=pm,
                    pad_n=pn,
                    pad_k=pk,
                    double_smem_buffer=dsb,
                    num_groups_to_merge=ngm,
                ),
                dtype_input=self.dtype_input,
                dtype_weight=self.dtype_weight,
                dtype_output=self.dtype_output,
                dtype_acc=self.dtype_acc,
                variant=self.variant,
                ndim=self.ndim,
                layout=self.layout,
                gpu_target=gpu_target,
                vector_size_a=vec_a,
                vector_size_b=vec_b,
                vector_size_c=vec_c,
                block_per_cu=bpc,
                num_wave_groups=nwg,
            )

    def config_count(self) -> int:
        """Total number of configurations generate_configs() will yield."""
        total = len(self.gpu_targets)
        for values in (
            self.tile_m_values,
            self.tile_n_values,
            self.tile_k_values,
            self.warp_m_values,
            self.warp_n_values,
            self.warp_k_values,
            self.warp_tile_m_values,
            self.warp_tile_n_values,
            self.warp_tile_k_values,
            self.pipeline_values,
            self.scheduler_values,
            self.epilogue_values,
            self.pad_m_values,
            self.pad_n_values,
            self.pad_k_values,
            self.double_smem_buffer_values,
            self.num_groups_to_merge_values,
            self.vector_size_a_values,
            self.vector_size_b_values,
            self.vector_size_c_values,
            self.block_per_cu_values,
            self.num_wave_groups_values,
        ):
            total *= len(values)
        return total
def load_conv_kernel_configs(json_path: str | Path) -> ConvKernelConfigSet:
    """
    Load convolution kernel configurations from a JSON file.
    Args:
        json_path: Path to JSON configuration file
    Returns:
        ConvKernelConfigSet with all parameter values loaded
    """
    json_path = Path(json_path)
    with open(json_path) as f:
        data = json.load(f)

    config_set = ConvKernelConfigSet()
    config_set.name = data.get("kernel_set_name", json_path.stem)

    # Optional "datatype" section with per-operand types.
    if "datatype" in data:
        dt = data["datatype"]
        config_set.dtype_input = dt.get("input", "fp16")
        config_set.dtype_weight = dt.get("weight", "fp16")
        config_set.dtype_output = dt.get("output", "fp16")
        config_set.dtype_acc = dt.get("acc", "fp32")

    # Convolution-specific scalars.
    config_set.variant = data.get("variant", "forward")
    config_set.ndim = data.get("ndim", 2)
    config_set.layout = data.get("layout", "nhwgc")

    # GPU targets: plural key takes precedence over the singular form.
    if "gpu_targets" in data:
        config_set.gpu_targets = data["gpu_targets"]
    elif "gpu_target" in data:
        config_set.gpu_targets = [data["gpu_target"]]

    # (JSON section, parameter name, default values) for every ranged parameter.
    ranged_params = (
        ("tile_config", "tile_m", [128]),
        ("tile_config", "tile_n", [128]),
        ("tile_config", "tile_k", [32]),
        ("tile_config", "warp_m", [2]),
        ("tile_config", "warp_n", [2]),
        ("tile_config", "warp_k", [1]),
        ("tile_config", "warp_tile_m", [32]),
        ("tile_config", "warp_tile_n", [32]),
        ("tile_config", "warp_tile_k", [16]),
        ("trait_config", "pipeline", ["compv3"]),
        ("trait_config", "scheduler", ["intrawave"]),
        ("trait_config", "epilogue", ["cshuffle"]),
        ("trait_config", "pad_m", [True]),
        ("trait_config", "pad_n", [True]),
        ("trait_config", "pad_k", [True]),
        ("trait_config", "double_smem_buffer", [False]),
        ("trait_config", "num_groups_to_merge", [1]),
        ("vector_config", "vector_size_a", [4]),
        ("vector_config", "vector_size_b", [8]),
        ("vector_config", "vector_size_c", [8]),
        ("occupancy_config", "block_per_cu", [1]),
        ("occupancy_config", "num_wave_groups", [1]),
    )
    for section, attr, default in ranged_params:
        values = _get_values(data.get(section, {}), attr, default)
        setattr(config_set, f"{attr}_values", values)

    return config_set
def generate_cpp_conv_kernel_set_declaration(
config_set: ConvKernelConfigSet,
set_name: Optional[str] = None,
) -> str:
"""
Generate C++ DECL_CONV_KERNEL_SET code from a ConvKernelConfigSet.
"""
name = set_name or config_set.name
lines = [f"DECL_CONV_KERNEL_SET({name},"]
for config in config_set.generate_configs():
line = f' .add("{config.dtype_input}", "{config.variant}", {config.ndim}, '
line += f"{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k})"
lines.append(line)
lines.append(");")
return "\n".join(lines)
# =============================================================================
# GEMM Configuration Export Functions
# =============================================================================
def generate_cpp_kernel_set_declaration(
config_set: KernelConfigSet,
set_name: Optional[str] = None,
) -> str:
"""
Generate C++ DECL_KERNEL_SET code from a KernelConfigSet.
Args:
config_set: The kernel configuration set
set_name: Optional name override for the kernel set
Returns:
C++ code string with DECL_KERNEL_SET declaration
"""
name = set_name or config_set.name
lines = [f"DECL_KERNEL_SET({name},"]
for config in config_set.generate_configs():
# Generate .add() call for each config
line = f' .add("{config.dtype_a}", "{config.layout}", '
line += f"{config.tile.tile_m}, {config.tile.tile_n}, {config.tile.tile_k})"
lines.append(line)
lines.append(");")
return "\n".join(lines)
# CLI for testing
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # Usage/error text goes to stderr so stdout stays clean for piping
        # the generated summary or C++ declarations into other tools.
        print("Usage: python kernel_config_loader.py <config.json>", file=sys.stderr)
        print("\nLoads kernel configurations from JSON and prints summary.", file=sys.stderr)
        sys.exit(1)

    json_path = sys.argv[1]
    try:
        config_set = load_kernel_configs(json_path)
        print(f"Kernel Set: {config_set.name}")
        print(
            f"Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}, Acc={config_set.dtype_acc}"
        )
        print(f"Layout: {config_set.layout}")
        print(f"GPU Targets: {config_set.gpu_targets}")
        print(f"Variant: {config_set.variant}")
        print()
        print("Tile Configurations:")
        print(f" tile_m: {config_set.tile_m_values}")
        print(f" tile_n: {config_set.tile_n_values}")
        print(f" tile_k: {config_set.tile_k_values}")
        print(f" warp_m: {config_set.warp_m_values}")
        print(f" warp_n: {config_set.warp_n_values}")
        print(f" warp_k: {config_set.warp_k_values}")
        print(
            f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}"
        )
        print()
        print("Trait Configurations:")
        print(f" pipeline: {config_set.pipeline_values}")
        print(f" scheduler: {config_set.scheduler_values}")
        print(f" epilogue: {config_set.epilogue_values}")
        print(
            f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}"
        )
        print()
        print(f"Total configurations: {config_set.config_count()}")
        print()
        # Print first few config names
        print("Sample kernel names:")
        for i, config in enumerate(config_set.generate_configs()):
            if i >= 5:
                print(f" ... and {config_set.config_count() - 5} more")
                break
            print(f" {config.kernel_name()}")
        print()
        # Generate C++ code
        if "--cpp" in sys.argv:
            print("C++ Declaration:")
            print("-" * 60)
            print(generate_cpp_kernel_set_declaration(config_set))
    except Exception as e:
        # Report failures on stderr and exit non-zero so scripted callers
        # (and shell pipelines) can detect the error.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)

View File

@@ -0,0 +1,518 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Preselected, Benchmarked Kernel Configurations
Curated kernel sets optimized for different workload characteristics:
- Compute-friendly: Large tiles, high arithmetic intensity
- Memory-friendly: Smaller tiles, better memory access patterns
- Latency-friendly: Minimal tiles, low latency for small problems
"""
from functools import partial, lru_cache
from typing import List
from unified_gemm_codegen import KernelConfig, TileConfig, TraitConfig, GemmVariant
# ============================================================================
# Base Configurations
# ============================================================================
def _base_fp16_rcr_compute() -> partial:
    """Base configuration for compute-intensive FP16 RCR kernels"""
    compute_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # tile is a placeholder; callers supply the real TileConfig per kernel.
    return partial(
        KernelConfig,
        tile=None,
        trait=compute_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
def _base_fp16_rcr_memory() -> partial:
    """Base configuration for memory-intensive FP16 RCR kernels"""
    # Note: Use 'mem' pipeline for interwave scheduler
    # (compv3/compv4/compv5/compv6 only support intrawave).
    memory_trait = TraitConfig(
        pipeline="mem",
        epilogue="cshuffle",
        scheduler="interwave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # tile is a placeholder; callers supply the real TileConfig per kernel.
    return partial(
        KernelConfig,
        tile=None,
        trait=memory_trait,
        variant=GemmVariant.STANDARD,
        block_size=128,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
def _base_fp16_rcr_latency() -> partial:
    """Base configuration for latency-sensitive FP16 RCR kernels"""
    latency_trait = TraitConfig(
        pipeline="mem",
        epilogue="default",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    # tile is a placeholder; callers supply the real TileConfig per kernel.
    return partial(
        KernelConfig,
        tile=None,
        trait=latency_trait,
        variant=GemmVariant.STANDARD,
        block_size=128,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
# ============================================================================
# Preselected FP16 RCR Kernels
# ============================================================================
@lru_cache(None)
def preselected_fp16_rcr_compute() -> List[KernelConfig]:
    """
    Compute-friendly FP16 RCR kernels
    Optimized for:
    - Large M, N dimensions (>= 128)
    - High arithmetic intensity
    - Good occupancy
    - Maximum throughput
    """
    base = _base_fp16_rcr_compute()
    # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, wt_m, wt_n, wt_k)
    tile_shapes = [
        (256, 256, 32, 4, 4, 1, 32, 32, 16),  # large tiles for maximum compute
        (256, 256, 64, 4, 4, 1, 32, 32, 16),
        (256, 128, 32, 4, 2, 1, 32, 32, 16),
        (128, 256, 32, 2, 4, 1, 32, 32, 16),
        (128, 128, 32, 2, 2, 1, 32, 32, 16),  # balanced tiles
        (128, 128, 64, 2, 2, 1, 32, 32, 16),
    ]
    configs = [base(tile=TileConfig(*shape)) for shape in tile_shapes]
    # Persistent-kernel variant (no padding) for large batches.
    configs.append(
        base(
            tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16),
            trait=TraitConfig(
                pipeline="compv4",
                epilogue="cshuffle",
                scheduler="intrawave",
                pad_m=False,
                pad_n=False,
                pad_k=False,
                persistent=True,
            ),
        )
    )
    return configs
@lru_cache(None)
def preselected_fp16_rcr_memory() -> List[KernelConfig]:
    """
    Memory-friendly FP16 RCR kernels
    Optimized for:
    - Small to medium M, N dimensions
    - Memory-bound workloads
    - Better cache utilization
    - Lower register pressure
    """
    base = _base_fp16_rcr_memory()
    # (tile_m, tile_n, tile_k, warp_m, warp_n, warp_k, wt_m, wt_n, wt_k)
    tile_shapes = [
        (16, 32, 32, 1, 1, 1, 16, 16, 16),  # small tiles for memory efficiency
        (32, 16, 32, 1, 1, 1, 16, 16, 16),
        (16, 64, 32, 1, 2, 1, 16, 16, 16),
        (64, 16, 32, 2, 1, 1, 16, 16, 16),
        (32, 64, 32, 1, 1, 1, 32, 32, 16),  # medium tiles
        (64, 32, 32, 1, 1, 1, 32, 32, 16),
        (32, 128, 32, 1, 2, 1, 32, 32, 16),
        (128, 32, 32, 2, 1, 1, 32, 32, 16),
    ]
    return [base(tile=TileConfig(*shape)) for shape in tile_shapes]
@lru_cache(None)
def preselected_fp16_rcr_latency() -> List[KernelConfig]:
    """
    Latency-friendly FP16 RCR kernels
    Optimized for:
    - Very small M, N dimensions (< 64)
    - Minimal launch overhead
    - Low latency
    - Quick execution
    """
    base = _base_fp16_rcr_latency()
    # Minimal tiles for low latency.
    tile_shapes = [
        (16, 32, 32, 1, 1, 1, 16, 16, 16),
        (32, 16, 32, 1, 1, 1, 16, 16, 16),
    ]
    return [base(tile=TileConfig(*shape)) for shape in tile_shapes]
# ============================================================================
# Preselected Multi-D Kernels
# ============================================================================
@lru_cache(None)
def preselected_fp16_rcr_multi_d() -> List[KernelConfig]:
    """
    Multi-D GEMM kernels with element-wise fusion
    Common fusions:
    - MultiDAdd: E = C + D0 + D1
    - Relu: E = max(C, 0)
    - Gelu: E = gelu(C)
    """
    base = _base_fp16_rcr_compute()
    # Best-performing tile for fused operations (shared across all configs).
    tile = TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)
    # One config per (element-wise op, D-tensor count) pair.
    return [
        base(
            tile=tile,
            variant=GemmVariant.MULTI_D,
            elementwise_op=ew_op,
            num_d_tensors=num_d,
        )
        for ew_op in ("MultiDAdd", "Relu", "Gelu", "FastGelu")
        for num_d in (1, 2)
    ]
@lru_cache(None)
def preselected_fp16_rcr_preshuffle() -> List[KernelConfig]:
    """
    Preshuffle GEMM kernels for weight optimization
    Best for:
    - Repeated use of same weights
    - Inference workloads
    - Batch size > 1
    """
    base = _base_fp16_rcr_compute()
    tile_shapes = [
        (256, 256, 32, 4, 4, 1, 32, 32, 16),
        (128, 128, 32, 2, 2, 1, 32, 32, 16),
    ]
    return [
        base(
            tile=TileConfig(*shape),
            variant=GemmVariant.PRESHUFFLE,
            preshuffle=True,
        )
        for shape in tile_shapes
    ]
# ============================================================================
# Unified Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_fp16_rcr_all() -> List[KernelConfig]:
    """All preselected FP16 RCR kernels"""
    combined: List[KernelConfig] = []
    for group in (
        preselected_fp16_rcr_compute(),
        preselected_fp16_rcr_memory(),
        preselected_fp16_rcr_latency(),
        preselected_fp16_rcr_multi_d(),
        preselected_fp16_rcr_preshuffle(),
    ):
        combined.extend(group)
    return combined
@lru_cache(None)
def preselected_fp16_rcr_essential() -> List[KernelConfig]:
    """
    Essential FP16 RCR kernels - minimal set for most workloads
    Covers:
    - 90% of common GEMM sizes
    - Key fusion operations
    - Balanced performance
    """
    base_compute = _base_fp16_rcr_compute()
    base_memory = _base_fp16_rcr_memory()
    configs = [
        # Top compute kernels
        base_compute(tile=TileConfig(256, 256, 32, 4, 4, 1, 32, 32, 16)),
        base_compute(tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16)),
        # Top memory kernels
        base_memory(tile=TileConfig(32, 64, 32, 1, 1, 1, 32, 32, 16)),
        base_memory(tile=TileConfig(64, 32, 32, 1, 1, 1, 32, 32, 16)),
    ]
    # Essential fusions (one fresh tile per config, as before).
    for ew_op in ("Relu", "Gelu"):
        configs.append(
            base_compute(
                tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
                variant=GemmVariant.MULTI_D,
                elementwise_op=ew_op,
                num_d_tensors=1,
            )
        )
    return configs
# ============================================================================
# Default Fallback
# ============================================================================
def default_kernel() -> KernelConfig:
    """
    Default fallback kernel - guaranteed to work
    Known-good configuration tested on gfx942
    """
    fallback_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    return KernelConfig(
        tile=TileConfig(128, 128, 32, 2, 2, 1, 32, 32, 16),
        trait=fallback_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
        k_block_per_cu=1,
        num_wave_groups=1,
    )
# ============================================================================
# BF16 Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_bf16_rcr_essential() -> List[KernelConfig]:
    """Essential BF16 RCR kernels"""
    shared_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    base_compute = partial(
        KernelConfig,
        tile=None,
        trait=shared_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
    )
    tile_shapes = [
        (256, 256, 32, 4, 4, 1, 32, 32, 16),
        (128, 128, 32, 2, 2, 1, 32, 32, 16),
    ]
    return [base_compute(tile=TileConfig(*shape)) for shape in tile_shapes]
# ============================================================================
# INT8 Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_int8_rcr_essential() -> List[KernelConfig]:
    """Essential INT8 RCR kernels for quantized inference"""
    shared_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    base = partial(
        KernelConfig,
        tile=None,
        trait=shared_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
    )
    tile_shapes = [
        (256, 256, 64, 4, 4, 1, 32, 32, 16),
        (128, 128, 64, 2, 2, 1, 32, 32, 16),
    ]
    return [base(tile=TileConfig(*shape)) for shape in tile_shapes]
# ============================================================================
# FP8 Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_fp8_rcr_essential() -> List[KernelConfig]:
    """Essential FP8 RCR kernels for AI training"""
    shared_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    base = partial(
        KernelConfig,
        tile=None,
        trait=shared_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
    )
    tile_shapes = [
        (256, 256, 64, 4, 4, 1, 32, 32, 16),
        (128, 128, 64, 2, 2, 1, 32, 32, 16),
    ]
    return [base(tile=TileConfig(*shape)) for shape in tile_shapes]
# ============================================================================
# Mixed Precision Preselected Sets
# ============================================================================
@lru_cache(None)
def preselected_mixed_precision() -> List[KernelConfig]:
    """Mixed-precision kernels (FP16 inputs, FP32 output)"""
    shared_trait = TraitConfig(
        pipeline="compv4",
        epilogue="cshuffle",
        scheduler="intrawave",
        pad_m=True,
        pad_n=True,
        pad_k=True,
        persistent=False,
    )
    base = partial(
        KernelConfig,
        tile=None,
        trait=shared_trait,
        variant=GemmVariant.STANDARD,
        block_size=256,
    )
    tile_shapes = [
        (256, 256, 32, 4, 4, 1, 32, 32, 16),
        (128, 128, 32, 2, 2, 1, 32, 32, 16),
    ]
    return [base(tile=TileConfig(*shape)) for shape in tile_shapes]
# ============================================================================
# Registry
# ============================================================================
# Registry mapping set name -> zero-argument factory returning its kernel list
# (looked up by get_preselected_set / list_preselected_sets below).
PRESELECTED_SETS = {
    # FP16 sets
    "fp16_rcr_compute": preselected_fp16_rcr_compute,
    "fp16_rcr_memory": preselected_fp16_rcr_memory,
    "fp16_rcr_latency": preselected_fp16_rcr_latency,
    "fp16_rcr_multi_d": preselected_fp16_rcr_multi_d,
    "fp16_rcr_preshuffle": preselected_fp16_rcr_preshuffle,
    "fp16_rcr_all": preselected_fp16_rcr_all,
    "fp16_rcr_essential": preselected_fp16_rcr_essential,
    # BF16 sets
    "bf16_rcr_essential": preselected_bf16_rcr_essential,
    # INT8 sets
    "int8_rcr_essential": preselected_int8_rcr_essential,
    # FP8 sets
    "fp8_rcr_essential": preselected_fp8_rcr_essential,
    # Mixed precision
    "mixed_precision": preselected_mixed_precision,
}
def get_preselected_set(name: str) -> List[KernelConfig]:
    """Get a preselected kernel set by name"""
    if name in PRESELECTED_SETS:
        return PRESELECTED_SETS[name]()
    raise ValueError(
        f"Unknown preselected set: {name}. Available: {list(PRESELECTED_SETS.keys())}"
    )
def list_preselected_sets() -> List[str]:
    """List all available preselected sets"""
    # Iterating a dict yields its keys in insertion order.
    return [set_name for set_name in PRESELECTED_SETS]
# ============================================================================
# CLI for testing
# ============================================================================
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(
        description="List preselected kernel configurations"
    )
    cli.add_argument(
        "--set",
        type=str,
        default="fp16_rcr_essential",
        choices=list_preselected_sets(),
        help="Preselected set to display",
    )
    cli.add_argument("--count-only", action="store_true", help="Only show count")
    opts = cli.parse_args()

    selected = get_preselected_set(opts.set)
    if opts.count_only:
        print(f"{opts.set}: {len(selected)} kernels")
    else:
        print(f"Preselected set: {opts.set}")
        print(f"Total kernels: {len(selected)}\n")
        for idx, cfg in enumerate(selected, 1):
            print(f"{idx}. {cfg.variant.value}")
            print(f" Tile: {cfg.tile.tile_m}x{cfg.tile.tile_n}x{cfg.tile.tile_k}")
            print(f" Pipeline: {cfg.trait.pipeline}, Epilogue: {cfg.trait.epilogue}")
            if cfg.variant == GemmVariant.MULTI_D:
                print(
                    f" Element-wise: {cfg.elementwise_op}, D tensors: {cfg.num_d_tensors}"
                )
            print()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,448 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
cmake_minimum_required(VERSION 3.16)

# Get processor count for parallel builds
include(ProcessorCount)
ProcessorCount(NPROC)
if(NPROC EQUAL 0)
    # ProcessorCount() reports 0 when detection fails; fall back to 4 jobs.
    set(NPROC 4)
endif()

# GPU target architecture (passed from command line or default to gfx942)
if(NOT DEFINED GPU_TARGETS OR GPU_TARGETS STREQUAL "")
    set(GPU_TARGETS "gfx942" CACHE STRING "GPU architecture target")
endif()

# Extract first target if multiple are provided (we only support single target builds)
# The two REPLACE calls normalize both ";"- and space-separated inputs into a
# proper CMake list before taking element 0.
string(REPLACE ";" " " GPU_TARGETS_SPACE "${GPU_TARGETS}")
string(REPLACE " " ";" GPU_TARGETS_LIST "${GPU_TARGETS_SPACE}")
list(GET GPU_TARGETS_LIST 0 GPU_TARGET)
message(STATUS "Building for GPU target: ${GPU_TARGET}")

# NOTE: Per-kernel compilation is now automatic via declarative examples
# Each example generates only its declared kernels (from DECL_KERNEL_SET)

# Link to dispatcher library
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/../build)

# =============================================================================
# Kernel Output Directory
# =============================================================================
# Shared output directory for all generated kernel headers.
set(KERNEL_OUTPUT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels")
file(MAKE_DIRECTORY ${KERNEL_OUTPUT_DIR})
# =============================================================================
# Kernel Generation Targets (run during 'make', not 'cmake')
# =============================================================================

# Sentinel files to track generation
# The sentinel file lets the custom command participate in normal dependency
# tracking: kernels are regenerated only when the sentinel is absent.
# NOTE(review): the command has no DEPENDS on the codegen script itself, so
# editing unified_gemm_codegen.py will NOT retrigger generation — use the
# regenerate_* targets for that.
set(GEMM_SENTINEL "${KERNEL_OUTPUT_DIR}/.gemm_generated")

# Generate GEMM kernels (standard + preshuffle + multi_d) - runs with internal parallelism
# Note: 4-char layout "rcrr" means A=row, B=col, C=row, D=row (for multi-d)
add_custom_command(
    OUTPUT ${GEMM_SENTINEL}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py
            --datatype fp16 --layout rcrr --variants standard preshuffle multi_d
            --output ${KERNEL_OUTPUT_DIR}
    COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen
    COMMENT "Generating GEMM kernels (fp16, rcrr, standard + preshuffle + multi_d) with internal parallelism..."
    VERBATIM
)

# Buildable entry point for the custom command above.
add_custom_target(generate_gemm_kernels
    DEPENDS ${GEMM_SENTINEL}
    COMMENT "GEMM kernel generation target"
)

# Alias for generate_all_kernels (GEMM only now)
add_custom_target(generate_all_kernels
    DEPENDS generate_gemm_kernels
)
# =============================================================================
# Per-Kernel Compilation (Maximum Parallelism)
# =============================================================================
# Enable with: cmake -DPER_KERNEL_COMPILATION=ON
#
# This creates ONE translation unit per kernel, enabling:
#   1. Maximum parallelism with make -j$(nproc)
#   2. Per-kernel build progress: "[1/128] Building kernel: gemm_fp16_128x128"
#   3. Incremental rebuilds (only changed kernels recompile)
#   4. Fine-grained build time analysis
#
# Build process:
#   1. Generate kernel headers (.hpp)
#   2. Generate wrapper files (.cpp) - one per kernel
#   3. Compile each wrapper in parallel
#   4. Link all objects into libdispatcher_kernels.so
#
# Example output:
#   [  1/128] Building kernel: gemm_fp16_rcr_128x128x32
#   [  2/128] Building kernel: gemm_fp16_rcr_256x256x64
#   ...
#   [128/128] Linking: libdispatcher_kernels.so
# =============================================================================
set(WRAPPER_DIR "${CMAKE_BINARY_DIR}/kernel_wrappers")
# Sentinel marking that wrapper generation already ran (same pattern as
# GEMM_SENTINEL above).
set(WRAPPER_SENTINEL "${WRAPPER_DIR}/.wrappers_generated")

# Target: Generate wrapper .cpp files (one per kernel)
# Depends on ${GEMM_SENTINEL} so kernel headers exist before wrappers are made.
add_custom_command(
    OUTPUT ${WRAPPER_SENTINEL}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/generate_kernel_wrappers.py
            --kernel-dir ${KERNEL_OUTPUT_DIR}
            --output-dir ${WRAPPER_DIR}
            --generate-makefile
            --generate-cmake
    COMMAND ${CMAKE_COMMAND} -E touch ${WRAPPER_SENTINEL}
    DEPENDS ${GEMM_SENTINEL}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen
    COMMENT "Generating per-kernel wrapper .cpp files..."
    VERBATIM
)

add_custom_target(generate_kernel_wrappers
    DEPENDS ${WRAPPER_SENTINEL}
    COMMENT "Kernel wrapper generation target"
)
# Target: Build kernels using generated Makefile (true per-kernel progress)
#
# FIX: CMake COMMANDs are executed directly, not through a shell, so the
# previous "make ... 2>&1 | grep ..." form passed "2>&1" and "|" to make as
# literal arguments. Wrap the pipeline in an explicit "sh -c" so the
# redirection and filtering actually happen. (As with any pipeline, the exit
# status is grep's, matching the original best-effort filtering intent.)
add_custom_target(build_kernels_parallel
    COMMAND ${CMAKE_COMMAND} -E echo "Building kernels with per-kernel progress..."
    COMMAND sh -c "make -C ${WRAPPER_DIR} -j${NPROC} 2>&1 | grep -E '^\\[|Built|Linking|Error'"
    DEPENDS generate_kernel_wrappers
    WORKING_DIRECTORY ${WRAPPER_DIR}
    COMMENT "Compiling kernels in parallel (one translation unit per kernel)..."
    VERBATIM
)
# Global kernel build (optional - prefer per-example builds for minimal compilation)
# This builds ALL kernels into a shared library - use for Python bindings or full library
# For C++ examples, use declarative approach which builds only needed kernels
# Delegates compilation to parallel_kernel_builder.py, which handles its own
# job scheduling (hence -jobs rather than relying on the generator).
add_custom_target(dispatcher_kernels
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/parallel_kernel_builder.py
            --kernel-dir ${KERNEL_OUTPUT_DIR}
            --output-dir ${CMAKE_BINARY_DIR}
            --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include"
            --jobs ${NPROC}
    DEPENDS generate_all_kernels
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts
    COMMENT "Building ALL kernels in parallel (prefer per-example builds for minimal compilation)..."
    VERBATIM
)
# =============================================================================
# Force regeneration targets (useful when you want to regenerate)
# =============================================================================
# FIX: this invocation previously used "--layout rcr" while the primary
# generation rule (the ${GEMM_SENTINEL} custom command) uses the 4-char
# layout "rcrr" (A=row, B=col, C=row, D=row) required by the multi_d
# variant requested here. Kept in sync with that rule.
add_custom_target(regenerate_gemm_kernels
    COMMAND ${CMAKE_COMMAND} -E remove -f ${GEMM_SENTINEL}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py
            --datatype fp16 --layout rcrr --variants standard preshuffle multi_d
            --output ${KERNEL_OUTPUT_DIR}
    COMMAND ${CMAKE_COMMAND} -E touch ${GEMM_SENTINEL}
    WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen
    COMMENT "Force regenerating GEMM kernels (standard + preshuffle + multi_d)..."
    VERBATIM
)

add_custom_target(regenerate_all_kernels
    DEPENDS regenerate_gemm_kernels
)

# Clean all per-example kernel directories
add_custom_target(clean_example_kernels
    COMMAND ${CMAKE_COMMAND} -E echo "Removing per-example kernel directories..."
    COMMAND find ${CMAKE_BINARY_DIR} -maxdepth 1 -type d -name "*_kernels" -exec rm -rf {} +
    COMMENT "Cleaning all per-example kernel directories..."
    VERBATIM
)
# =============================================================================
# Helper function to add a GPU example with force-included kernel
# =============================================================================
# Registers an example executable that uses the dispatcher registry.
# KERNEL_HEADER is either:
#   - register_all_kernels.hpp : the example source includes it directly
#   - a specific kernel header : injected through the compiler's -include flag
function(add_gpu_example NAME SOURCE KERNEL_HEADER)
    add_executable(${NAME} ${SOURCE})
    target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher)
    target_include_directories(${NAME} PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/../../include                                  # CK root include
        ${CMAKE_CURRENT_SOURCE_DIR}/../include                                     # Dispatcher include
        ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels                     # Generated kernels
        ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels/dispatcher_wrappers # Wrapper headers
    )

    # Compiler flags shared by both header styles.
    set(_GPU_EXAMPLE_COMMON_OPTS
        -mllvm -enable-noalias-to-md-conversion=0
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
    )

    get_filename_component(HEADER_NAME ${KERNEL_HEADER} NAME)
    if(HEADER_NAME STREQUAL "register_all_kernels.hpp")
        # Registration header: the source includes it itself, only flag it.
        target_compile_options(${NAME} PRIVATE
            -DGEMM_KERNEL_AVAILABLE=1
            ${_GPU_EXAMPLE_COMMON_OPTS}
        )
    else()
        # Specific kernel header: force-include it.
        target_compile_options(${NAME} PRIVATE
            -include ${KERNEL_HEADER}
            ${_GPU_EXAMPLE_COMMON_OPTS}
        )
    endif()

    if(hip_FOUND)
        target_link_libraries(${NAME} PRIVATE hip::device hip::host)
    endif()
endfunction()
# Standalone GPU example: the source instantiates its kernel directly, so no
# pre-generated kernel header is force-included.
function(add_standalone_gpu_example NAME SOURCE)
    add_executable(${NAME} ${SOURCE})
    target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher)
    target_include_directories(${NAME} PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/../../include              # CK root include
        ${CMAKE_CURRENT_SOURCE_DIR}/../include                 # Dispatcher include
        ${CMAKE_CURRENT_SOURCE_DIR}/../build/generated_kernels # Generated kernels (optional)
    )
    set(_STANDALONE_OPTS
        -mllvm -enable-noalias-to-md-conversion=0
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
    )
    target_compile_options(${NAME} PRIVATE ${_STANDALONE_OPTS})
    if(hip_FOUND)
        target_link_libraries(${NAME} PRIVATE hip::device hip::host)
    endif()
endfunction()
# Declarative configuration demo: generates no kernels, but the CK headers
# still require the HIP compiler, hence the shared warning/codegen flags.
function(add_declarative_example NAME SOURCE)
    add_executable(${NAME} ${SOURCE})
    target_include_directories(${NAME} PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/../../include
        ${CMAKE_CURRENT_SOURCE_DIR}/../include
    )
    set(_DECLARATIVE_OPTS
        -Wno-float-equal
        -Wno-unused-variable
        -Wno-undefined-func-template
        -mllvm -enable-noalias-to-md-conversion=0
    )
    target_compile_options(${NAME} PRIVATE ${_DECLARATIVE_OPTS})
    target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher)
    if(hip_FOUND)
        target_link_libraries(${NAME} PRIVATE hip::device hip::host)
    endif()
endfunction()
# =============================================================================
# GEMM Examples
# =============================================================================
# Per-example kernel directories are created from DECL_KERNEL_SET declarations
# Each example gets its own: build/<name>_kernels/
# This prevents clashes during parallel compilation of multiple examples.
# Helper function to add example with declarative kernel support
# Parses DECL_KERNEL_SET from source and generates ONLY the declared kernels
# This enables minimal builds: only kernels needed by this example are generated
#
# Key features:
# - Per-example kernel directories: build/<name>_kernels/ (no clashes)
# - Automatic header inclusion: No hardcoded #include needed in source
# - Minimal builds: Only declared kernels are generated
# - Auto-regeneration: Kernels regenerated if directory missing
# - Parallel compilation: Each kernel is a separate translation unit
function(add_declarative_gpu_example NAME SOURCE)
    set(EXAMPLE_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/${SOURCE}")
    get_filename_component(EXAMPLE_STEM ${SOURCE} NAME_WE)

    # Per-example kernel directories
    set(EXAMPLE_KERNEL_DIR "${CMAKE_BINARY_DIR}/${NAME}_kernels")
    # Registration header, static library, and sentinel produced by
    # example_kernel_builder.py — these paths must match what the script emits.
    set(EXAMPLE_HEADER "${EXAMPLE_KERNEL_DIR}/${EXAMPLE_STEM}_kernels.hpp")
    set(EXAMPLE_LIB "${EXAMPLE_KERNEL_DIR}/lib${NAME}_kernels.a")
    set(EXAMPLE_SENTINEL "${EXAMPLE_KERNEL_DIR}/.generated")

    # Generate AND compile kernels in parallel at make time
    # This avoids slow cmake and gets per-kernel progress
    add_custom_command(
        OUTPUT ${EXAMPLE_SENTINEL} ${EXAMPLE_LIB}
        COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../scripts/example_kernel_builder.py
                ${EXAMPLE_SOURCE}
                --output-dir ${EXAMPLE_KERNEL_DIR}
                --include-dirs "${CMAKE_CURRENT_SOURCE_DIR}/../../include,${CMAKE_CURRENT_SOURCE_DIR}/../include"
                --gpu-target ${GPU_TARGET}
                --jobs ${NPROC}
                --target-name ${NAME}
        COMMAND ${CMAKE_COMMAND} -E touch ${EXAMPLE_SENTINEL}
        # Re-runs whenever the example source (and thus its DECL_KERNEL_SET)
        # changes, or when the kernel directory/sentinel is removed.
        DEPENDS ${EXAMPLE_SOURCE}
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../scripts
        COMMENT "[${NAME}] Generating and compiling kernels from DECL_KERNEL_SET..."
        VERBATIM
    )
    add_custom_target(generate_${NAME}_kernels DEPENDS ${EXAMPLE_SENTINEL})

    # Add the executable
    add_executable(${NAME} ${SOURCE})
    target_link_libraries(${NAME} PRIVATE ck_tile_dispatcher)
    # Link against the per-example kernel library
    target_link_libraries(${NAME} PRIVATE ${EXAMPLE_LIB})
    target_include_directories(${NAME} PRIVATE
        ${CMAKE_CURRENT_SOURCE_DIR}/../../include
        ${CMAKE_CURRENT_SOURCE_DIR}/../include
        ${EXAMPLE_KERNEL_DIR}
        ${EXAMPLE_KERNEL_DIR}/dispatcher_wrappers
    )
    # Force-include the generated registration header
    target_compile_options(${NAME} PRIVATE
        -include ${EXAMPLE_HEADER}
        -DGEMM_KERNEL_AVAILABLE=1
        -mllvm -enable-noalias-to-md-conversion=0
        -Wno-undefined-func-template
        -Wno-float-equal
        --offload-compress
    )
    if(hip_FOUND)
        target_link_libraries(${NAME} PRIVATE hip::device hip::host)
    endif()
    # Only depends on generating THIS example's kernels
    add_dependencies(${NAME} generate_${NAME}_kernels)
endfunction()
# GEMM C++ examples with declarative kernel support
# Each example's C++ code contains DECL_KERNEL_SET which declares needed kernels
# (kernels are generated and compiled at make time into build/<name>_kernels/).
add_declarative_gpu_example(gemm_01_basic gemm/cpp/01_basic_gemm.cpp)
add_declarative_gpu_example(gemm_02_multi_size gemm/cpp/02_multi_size.cpp)
add_declarative_gpu_example(gemm_03_benchmark_validation gemm/cpp/03_benchmark_validation.cpp)
add_declarative_gpu_example(gemm_04_heuristics gemm/cpp/04_heuristics.cpp)
add_declarative_gpu_example(gemm_05_json_export gemm/cpp/05_json_export.cpp)
add_declarative_gpu_example(gemm_06_multi_registry gemm/cpp/06_multi_registry.cpp)
# =============================================================================
# GEMM Python Library - Single Fallback Kernel
# =============================================================================
# Generate a single fallback kernel for the Python library (fp16, rcr, compv4)
set(GEMM_FALLBACK_KERNEL_DIR "${CMAKE_CURRENT_BINARY_DIR}/gemm_python_fallback")
# Exact header filename the codegen script emits for the tile config below;
# used as the custom command OUTPUT so CMake can track it.
set(GEMM_FALLBACK_KERNEL "${GEMM_FALLBACK_KERNEL_DIR}/gemm_fp16_rcr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16.hpp")

# Tile config JSON for single kernel generation
set(GEMM_FALLBACK_TILE_CONFIG "{\"tile_m\":[128],\"tile_n\":[128],\"tile_k\":[32],\"warp_m\":[2],\"warp_n\":[2],\"warp_k\":[1],\"warp_tile_m\":[32],\"warp_tile_n\":[32],\"warp_tile_k\":[16],\"pipeline\":[\"compv4\"],\"scheduler\":[\"intrawave\"],\"epilogue\":[\"cshuffle\"]}")

# Generate single fallback kernel (not all 6000+ kernels)
# NOTE(review): this call passes --output-dir while the other
# unified_gemm_codegen.py invocations in this file use --output — confirm the
# script accepts both spellings.
add_custom_command(
    OUTPUT ${GEMM_FALLBACK_KERNEL}
    COMMAND ${CMAKE_COMMAND} -E make_directory ${GEMM_FALLBACK_KERNEL_DIR}
    COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py
            --datatype fp16 --layout rcr --variants standard
            --gpu-target ${GPU_TARGET}
            --output-dir ${GEMM_FALLBACK_KERNEL_DIR}
            --tile-config-json "${GEMM_FALLBACK_TILE_CONFIG}"
    COMMENT "Generating single fallback GEMM kernel for Python library"
    VERBATIM
)

add_custom_target(generate_gemm_fallback_kernel DEPENDS ${GEMM_FALLBACK_KERNEL})

# GEMM dynamic library for Python
add_library(dispatcher_gemm_lib SHARED ${CMAKE_CURRENT_SOURCE_DIR}/../bindings/ctypes/gemm_ctypes_lib.cpp)
target_link_libraries(dispatcher_gemm_lib PRIVATE ck_tile_dispatcher)
target_include_directories(dispatcher_gemm_lib PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/../../include
    ${CMAKE_CURRENT_SOURCE_DIR}/../include
    ${GEMM_FALLBACK_KERNEL_DIR}
)
# Force-include the fallback kernel header and record the arch string so the
# shared library always has at least one kernel compiled in.
target_compile_options(dispatcher_gemm_lib PRIVATE
    -DCK_TILE_SINGLE_KERNEL_INCLUDE
    -include ${GEMM_FALLBACK_KERNEL}
    -DGFX_ARCH="${GPU_TARGET}"
    -mllvm -enable-noalias-to-md-conversion=0
    -Wno-undefined-func-template
    -Wno-float-equal
    --offload-compress
)
if(hip_FOUND)
    target_link_libraries(dispatcher_gemm_lib PRIVATE hip::device hip::host)
endif()
add_dependencies(dispatcher_gemm_lib generate_gemm_fallback_kernel)

message(STATUS "GEMM examples configured - kernels will be generated during 'make'")

# Convenience target to build all Python ctypes libraries
add_custom_target(python_libs
    DEPENDS dispatcher_gemm_lib
    COMMENT "Building Python ctypes libraries (GEMM)"
)
# =============================================================================
# Per-Architecture Kernel Generation Targets
# =============================================================================
set(SUPPORTED_GPU_ARCHS gfx942 gfx90a gfx1100 gfx1030)

foreach(ARCH ${SUPPORTED_GPU_ARCHS})
    # GEMM kernels for this arch
    # NOTE(review): unlike the main generation rule (layout "rcrr" with
    # explicit variants), this passes 3-char "rcr" and no --variants flag —
    # presumably the script defaults to standard-only here; confirm intended.
    add_custom_target(generate_gemm_kernels_${ARCH}
        COMMAND python3 ${CMAKE_CURRENT_SOURCE_DIR}/../codegen/unified_gemm_codegen.py
                --datatype fp16 --layout rcr --gpu-target ${ARCH}
                --output ${KERNEL_OUTPUT_DIR}
        WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../codegen
        COMMENT "Generating GEMM kernels for ${ARCH}..."
        VERBATIM
    )
    # Alias for kernels (GEMM only now)
    add_custom_target(generate_kernels_${ARCH}
        DEPENDS generate_gemm_kernels_${ARCH}
        COMMENT "Generating all kernels for ${ARCH}..."
    )
endforeach()
# =============================================================================
# Summary
# =============================================================================
# Configure-time help text summarizing the available build targets.
message(STATUS "")
message(STATUS "=== Dispatcher Examples Configuration ===")
message(STATUS "")
message(STATUS "Kernels will be generated automatically during 'make'")
message(STATUS "  Generated to: ${KERNEL_OUTPUT_DIR}")
message(STATUS "")
message(STATUS "Build targets:")
message(STATUS "  make                        - Build all examples (generates kernels first)")
message(STATUS "  make python_libs            - Build Python ctypes libraries")
message(STATUS "  make generate_all_kernels   - Generate all kernels only")
message(STATUS "  make regenerate_all_kernels - Force regenerate all kernels")
message(STATUS "")
message(STATUS "Per-architecture targets:")
message(STATUS "  make generate_kernels_<arch> - Generate for specific arch")
message(STATUS "  Supported archs: ${SUPPORTED_GPU_ARCHS}")
message(STATUS "")

View File

@@ -0,0 +1,210 @@
# CK Tile Dispatcher Examples
Comprehensive examples for GEMM operations with GPU execution.
> **Note**: Convolution examples have been moved to `ck-2/conv_archive/dispatcher/` for reference (see "Archived Examples" below).
---
## Quick Start
### Step 1: Build
```bash
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_BUILD_TYPE=Release \
-DGPU_TARGETS="gfx942" \
-DBUILD_DISPATCHER_EXAMPLES=ON
# Build everything (C++ examples + Python libraries)
make -j$(nproc)
# Or build ONLY Python libraries (faster)
make python_libs -j$(nproc)
```
### Step 2: Run C++ Examples
```bash
cd build/examples
# GEMM
./gemm_01_basic
./gemm_02_multi_size
./gemm_03_benchmark_validation
./gemm_04_heuristics
./gemm_05_json_export
./gemm_06_multi_registry
```
### Step 3: Run Python Examples
```bash
cd /path/to/composable_kernel/dispatcher
# GEMM
python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py
```
---
## Directory Structure
```
examples/
├── gemm/
│ ├── cpp/ # 6 C++ GEMM examples
│ └── python/ # 11 Python GEMM examples
└── README.md
```
---
## GEMM Examples
### C++ Examples
| # | Example | Description |
|---|---------|-------------|
| 01 | `gemm_01_basic` | Basic GEMM with declarative API, autofill, autocorrect |
| 02 | `gemm_02_multi_size` | Wildcard expansion for multiple configurations |
| 03 | `gemm_03_benchmark_validation` | Performance benchmarking with CPU/GPU validation |
| 04 | `gemm_04_heuristics` | Heuristic-based kernel selection |
| 05 | `gemm_05_json_export` | Registry JSON export for external tools |
| 06 | `gemm_06_multi_registry` | Multiple registries with named kernel sets |
**Details:** [gemm/cpp/README.md](gemm/cpp/README.md)
---
### Python Examples
| # | Example | Description |
|---|---------|-------------|
| 01 | `01_basic_gemm.py` | Basic GEMM with multi-kernel support |
| 02 | `02_batch_gemm.py` | Batched GEMM operations |
| 03 | `03_benchmark.py` | Performance benchmarking |
| 04 | `04_validation.py` | CPU reference validation |
| 05 | `05_numpy_integration.py` | NumPy array integration |
| 06 | `06_json_export.py` | Registry JSON export |
| 07 | `07_stress_test.py` | Multi-kernel stress testing (48 configs) |
| 08 | `08_heuristics.py` | Heuristic-based kernel selection (24 configs) |
| 09 | `09_multi_registry.py` | Multiple registries |
| 10 | `10_advanced_benchmark.py` | Advanced benchmark with full control |
| 11 | `11_json_import.py` | Import kernels from JSON |
**Details:** [gemm/python/README.md](gemm/python/README.md)
---
## Key Features
### Declarative Kernel API
Both C++ and Python examples use a declarative approach:
**C++ (DECL_KERNEL_SET macro):**
```cpp
DECL_KERNEL_SET(my_kernels,
.add(
Signature().dtype("fp16").layout("rcr"),
Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
.pipeline("compv4").scheduler("intrawave"),
"gfx942"
)
);
```
**Python (KernelConfig):**
```python
config = KernelConfig(
tile_m=256, tile_n=256, tile_k=32,
wave_m=2, wave_n=2, wave_k=1,
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
pipeline="compv4", scheduler="intrawave"
)
```
### Autofill and Autocorrect
The build system automatically:
- **Autofills** missing parameters with sensible defaults
- **Autocorrects** invalid parameters based on architecture constraints
- **Expands** wildcards (`*`, `-1`, `ANY_INT`) to all valid configurations
### Architecture Filtering
Kernel configurations are validated against GPU architecture constraints:
- Tile divisibility requirements
- Warp tile constraints
- Pipeline compatibility
Invalid configurations are automatically pruned during code generation.
---
## Validation Examples
### C++ Validation
```bash
./gemm_03_benchmark_validation --verify 1 # GEMM with CPU reference
./gemm_03_benchmark_validation --verify 2 # GEMM with GPU reference
```
### Python Validation
```bash
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py # Multi-kernel validation
```
---
## Troubleshooting
### Python: Library not found
```bash
# Run from dispatcher directory
cd /path/to/composable_kernel/dispatcher
python3 examples/gemm/python/01_basic_gemm.py
```
### C++: Executables not found
```bash
# Build with examples enabled
cmake .. -DBUILD_DISPATCHER_EXAMPLES=ON
make -j$(nproc)
# Run from build/examples
cd build/examples
./gemm_01_basic
```
### GPU not detected
```bash
rocminfo | grep "Name:"
# Should show: gfx942, gfx90a, etc.
```
---
## Archived Examples
Convolution examples have been archived to `ck-2/conv_archive/dispatcher/`:
- `examples/conv/cpp/` - 11 C++ convolution examples
- `examples/conv/python/` - 14 Python convolution examples
See the archive for convolution functionality reference.

View File

@@ -0,0 +1,243 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 01: Basic GEMM - Autofill, Autocorrect, and Full Declaration
*
* Demonstrates THREE declaration patterns:
*
* 1. AUTOFILL: Minimal declaration - missing params filled with defaults
* .add(Signature().dtype("fp16").layout("rcr"),
* Algorithm().tile(128,128,64).pipeline("compv3").scheduler("intrawave"),
* "gfx942")
* -> wave(2,2,1), warp(32,32,16), epilogue("cshuffle") added automatically
*
* 2. AUTOCORRECT: Invalid params corrected to valid values
* .add(..., Algorithm().wave(1,1,1)...)
* -> wave(1,1,1) is invalid for gfx942, corrected to wave(2,2,1)
*
* 3. FULL: All parameters explicitly specified
* .add(..., Algorithm().tile().wave().warp().pipeline().scheduler().epilogue()...)
*
* Build: cd dispatcher/build && cmake .. && make gemm_01_basic
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::backends;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// =============================================================================
// THREE KERNEL DECLARATION PATTERNS
// =============================================================================
// This kernel set is parsed by the build (example_kernel_builder.py reads the
// DECL_KERNEL_SET block to generate/compile exactly these kernels) and is
// looked up at runtime in main() via KernelSetRegistry under the name
// "basic_gemm_kernels".
DECL_KERNEL_SET(
    basic_gemm_kernels,

    // -------------------------------------------------------------------------
    // Pattern 1: AUTOFILL - Minimal declaration
    // Only specify: dtype, layout, tile, pipeline, scheduler
    // Auto-filled: wave(2,2,1), warp(32,32,16), epilogue("cshuffle"), pad(false,false,false)
    // -------------------------------------------------------------------------
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 64)      // Required
             .pipeline("compv3")      // Required
             .scheduler("intrawave"), // Required
         "gfx942")

    // -------------------------------------------------------------------------
    // Pattern 2: AUTOCORRECT - Invalid wave config
    // wave(1,1,1) is invalid for gfx942 WMMA, corrected to wave(2,2,1)
    // -------------------------------------------------------------------------
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 32) // Different tile_k to make unique kernel
             .wave(1, 1, 1)      // INVALID: autocorrected to (2,2,1)
             .warp(32, 32, 16)   // Valid warp for 128x128 tile
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")

    // -------------------------------------------------------------------------
    // Pattern 3: FULL - All parameters explicitly specified
    // No autofill or autocorrect needed
    // -------------------------------------------------------------------------
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 32)          // Explicit tile
             .wave(2, 2, 1)             // Explicit wave (valid)
             .warp(16, 16, 32)          // Explicit warp tile
             .pipeline("compv3")        // Explicit pipeline
             .scheduler("intrawave")    // Explicit scheduler
             .epilogue("cshuffle")      // Explicit epilogue
             .pad(false, false, false), // Explicit padding
         "gfx942"));
// =============================================================================
// MAIN
// =============================================================================
int main(int argc, char* argv[])
{
    // --- CLI setup -------------------------------------------------------
    // --list/--list-verbose dump the registry and exit; --size sets the
    // square problem dimension M=N=K; --arch selects the GPU arch string.
    ExampleArgs args("Example 01: GEMM Autofill/Autocorrect/Full",
                     "Three kernel declaration patterns");
    args.add_flag("--list", "List registered kernels");
    args.add_flag("--list-verbose", "List registered kernels with full details");
    args.add_option("--size", "1024", "Problem size MxNxK");
    args.add_option("--arch", "gfx942", "GPU architecture");
    // NOTE(review): a false return exits cleanly (0) — presumably parse()
    // returns false for --help or after reporting its own error; confirm.
    if(!args.parse(argc, argv))
        return 0;

    print_header("Example 01: GEMM Declaration Patterns");

    // =========================================================================
    // Show the Three Patterns (informational banner only)
    // =========================================================================
    std::cout << "\nTHREE DECLARATION PATTERNS:\n";
    std::cout << "============================\n\n";

    std::cout << "1. AUTOFILL (minimal declaration):\n";
    std::cout << " .add(Signature().dtype(\"fp16\").layout(\"rcr\"),\n";
    std::cout
        << " Algorithm().tile(128,128,64).pipeline(\"compv3\").scheduler(\"intrawave\"),\n";
    std::cout << " \"gfx942\")\n";
    std::cout << " -> Auto-filled: wave(2,2,1), warp(32,32,16), epilogue(\"cshuffle\")\n\n";

    std::cout << "2. AUTOCORRECT (invalid params fixed):\n";
    std::cout << " .add(..., Algorithm().wave(1,1,1)...)\n";
    std::cout << " -> wave(1,1,1) invalid for gfx942, corrected to wave(2,2,1)\n\n";

    std::cout << "3. FULL (all params explicit):\n";
    std::cout << " .add(..., "
                 "Algorithm().tile().wave().warp().pipeline().scheduler().epilogue().pad()...)\n";
    std::cout << " -> No changes needed\n\n";

    std::string gfx_arch = args.get("--arch", "gfx942");

    // =========================================================================
    // Step 1: Show Declared Kernel Sets
    // =========================================================================
    std::cout << "Step 1: Declared Kernel Sets\n";
    KernelSetRegistry::instance().print();
    // The name must match the DECL_KERNEL_SET above.
    const auto& decl_set = KernelSetRegistry::instance().get("basic_gemm_kernels");
    std::cout << " 'basic_gemm_kernels': " << decl_set.size() << " declaration(s)\n";

    // =========================================================================
    // Step 2: Create Registry and Register Kernels
    // =========================================================================
    std::cout << "\nStep 2: Register Kernels\n";
    Registry registry;
    // Use generic macro
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    std::cout << " Registered " << registry.size() << " kernel(s)\n";

    // List kernels if requested (then exit without running)
    if(args.has("--list") || args.has("--list-verbose"))
    {
        std::cout << "\n";
        print_registered_kernels(registry, std::cout, args.has("--list-verbose"));
        return 0;
    }

    // =========================================================================
    // Step 3: Create Dispatcher
    // =========================================================================
    std::cout << "\nStep 3: Create Dispatcher\n";
    // Dispatcher holds a pointer to `registry`, which must outlive it.
    Dispatcher dispatcher(&registry);

    // =========================================================================
    // Step 4: Setup Problem
    // =========================================================================
    int size = args.get_int("--size", 1024);
    const int M = size, N = size, K = size;
    std::cout << "\nStep 4: Setup Problem (" << M << "x" << N << "x" << K << ")\n";
    Problem problem(M, N, K);

    using DataType = ck_tile::fp16_t;
    // Device buffers sized in elements: A is MxK, B is KxN, C is MxN.
    GpuBuffer<DataType> a_dev(M * K);
    GpuBuffer<DataType> b_dev(K * N);
    GpuBuffer<DataType> c_dev(M * N);
    // All-ones inputs make every output element equal K (checked in Step 6).
    std::vector<DataType> a_host(M * K, DataType(1.0f));
    std::vector<DataType> b_host(K * N, DataType(1.0f));
    a_dev.copy_from_host(a_host.data());
    b_dev.copy_from_host(b_host.data());
    c_dev.zero();

    // =========================================================================
    // Step 5: Select and Run
    // =========================================================================
    std::cout << "\nStep 5: Select and Run\n";
    auto selected = dispatcher.select_kernel(problem);
    if(!selected)
    {
        std::cerr << "ERROR: No kernel found!\n";
        return 1;
    }
    std::cout << " Selected: " << selected->get_name() << "\n";

    // run() returns the measured kernel time in milliseconds.
    float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
    std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
    std::cout << " TFLOPS: " << std::setprecision(2) << calculate_tflops(M, N, K, time_ms) << "\n";

    // =========================================================================
    // Step 6: Verify
    // =========================================================================
    std::cout << "\nStep 6: Verify\n";
    std::vector<DataType> c_host(M * N);
    c_dev.copy_to_host(c_host.data());
    // With A = B = 1 every C element should be K; tolerance is 1% relative
    // plus 1.0 absolute slack for fp16 accumulation error.
    const float expected = static_cast<float>(K);
    int errors = 0;
    for(int i = 0; i < M * N; ++i)
    {
        if(std::abs(static_cast<float>(c_host[i]) - expected) > 0.01f * expected + 1.0f)
            ++errors;
    }
    bool passed = (errors == 0);
    std::cout << " Expected: " << expected << ", Errors: " << errors << "\n";
    std::cout << " Status: " << (passed ? "PASS" : "FAIL") << "\n";

    // =========================================================================
    // Summary
    // =========================================================================
    print_separator();
    std::cout << "DECLARATION PATTERNS SUMMARY:\n";
    print_separator();
    std::cout << R"(
1. AUTOFILL: Specify only required params, system fills defaults
- Useful for quick prototyping
- Guarantees valid configuration
2. AUTOCORRECT: System validates and fixes invalid params
- wave(1,1,1) -> wave(2,2,1) on gfx942
- Invalid pipeline/scheduler combos fixed
- Logs corrections for debugging
3. FULL: All params explicit - no changes made
- Full control over configuration
- Best for production/tuning
)";
    print_separator();
    // Exit code reflects verification so scripts can gate on correctness.
    return passed ? 0 : 1;
}

View File

@@ -0,0 +1,215 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 02: Multi-Size GEMM with Wildcard Expansion
*
* Demonstrates the WILDCARD feature where specifying wildcards causes
* the build system to expand to ALL valid configurations for the architecture.
*
* WILDCARD SYNTAX:
* - Integer params: ANY_INT or -1 (both are equivalent, ANY_INT is just a #define for -1)
* - String params: "*" (for pipeline, scheduler)
*
* The kernel declaration:
* .add(..., Algorithm().tile(64,64,64).wave(ANY_INT,ANY_INT,1).warp(-1,-1,-1)
* .pipeline("*").scheduler("*"), ...)
*
* Expands to multiple kernels:
* - wave: (1,4,1), (2,2,1), (4,1,1) -> 3 options
* - warp: (16,16,32), (32,32,16) -> 2 options
* - pipeline: "compv3" -> 1 option (compv4 requires special handling)
* - scheduler: "intrawave" -> 1 option
*
* Raw expansion: 3 × 2 = 6 configs, but arch filter validates each:
* - tile_m must be divisible by (warp_m × warp_tile_m)
* - tile_n must be divisible by (warp_n × warp_tile_n)
* - Some wave/warp combos invalid: (4,1,1)+(32,32,16), (1,4,1)+(32,32,16)
* Result: 4 valid wildcard kernels + 1 explicit = 5 total
*
* Build: cd dispatcher/build && cmake .. && make gemm_02_multi_size
* Usage: ./gemm_02_multi_size [--max-size N] [--help]
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// =============================================================================
// KERNEL SET: Demonstrates Wildcard Expansion
// =============================================================================
// Declared kernel set "multi_size_kernels".
// The build system expands this declaration into concrete kernel
// instantiations: wildcard parameters multiply out into all candidate
// configurations, and the architecture filter then prunes combinations
// whose wave/warp products do not divide the tile.
DECL_KERNEL_SET(multi_size_kernels,
    // -------------------------------------------------------------------------
    // Kernel 1: Explicit - all parameters specified (no expansion)
    // -------------------------------------------------------------------------
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 32)
             .wave(2, 2, 1)
             .warp(16, 16, 32)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942") // arch tag: this kernel is only generated for gfx942
    // -------------------------------------------------------------------------
    // Kernel 2: WILDCARD - expands to multiple valid configurations
    // Wildcards: ANY_INT == -1 (for integers), "*" (for strings)
    // -------------------------------------------------------------------------
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 64)
             .wave(ANY_INT, ANY_INT, 1) // ANY_INT → (1,4,1), (2,2,1), (4,1,1)
             .warp(-1, -1, -1)          // -1 same as ANY_INT → (16,16,32), (32,32,16)
             .pipeline("*")             // "*" → valid pipelines
             .scheduler("*")            // "*" → valid schedulers
             .epilogue("cshuffle"),
         "gfx942"));
// Raw: 3×2=6, arch filter removes 2 invalid → 4 valid kernels
// =============================================================================
// MAIN
// =============================================================================
int main(int argc, char* argv[])
{
ExampleArgs args("Example 02: Multi-Size GEMM with Wildcards",
"Demonstrates wildcard expansion for kernel generation");
args.add_option("--max-size", "4096", "Maximum problem size to test");
args.add_option("--arch", "gfx942", "GPU architecture");
args.add_flag("--list", "List all registered kernels");
args.add_flag("--list-verbose", "List kernels with full configuration details");
if(!args.parse(argc, argv))
return 0;
int max_size = args.get_int("--max-size", 4096);
std::string gfx_arch = args.get("--arch", "gfx942");
print_header("Example 02: Multi-Size GEMM with Wildcards");
// =========================================================================
// Show Wildcard Expansion Concept
// =========================================================================
std::cout << "\nWILDCARD EXPANSION:\n";
std::cout << "===================\n";
std::cout << R"(
Wildcard syntax:
ANY_INT or -1 -> expands integer params to all valid values
"*" -> expands string params (pipeline/scheduler) to valid values
Declaration with wildcards:
.tile(64, 64, 64) -> fixed tile size (no wildcard)
.wave(ANY_INT, ANY_INT, 1) -> expands to (1,4,1), (2,2,1), (4,1,1) = 3
.warp(-1, -1, -1) -> expands to (16,16,32), (32,32,16) = 2
.pipeline("*") -> expands to valid pipelines = 1
.scheduler("*") -> expands to valid schedulers = 1
Expanded: 3 × 2 = 6 configs, but arch filter validates each:
- wave×warp must divide tile: (4,1,1)×(32,32,16) invalid for 64x64
- Result: 4 valid kernels from wildcard + 1 explicit = 5 total
)";
// =========================================================================
// Setup Registry and Dispatcher
// =========================================================================
std::cout << "\nStep 1: Register Kernels\n";
std::cout << "------------------------\n";
Registry registry;
registry.set_name("multi_size_registry");
// Register kernels from generated header (includes expanded wildcards)
// Use generic macro - no need to hardcode example name
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
std::cout << " Registered " << registry.size() << " kernel(s) from wildcard expansion\n";
if(args.has("--list") || args.has("--list-verbose"))
{
std::cout << "\n";
print_registered_kernels(registry, std::cout, args.has("--list-verbose"));
return 0;
}
Dispatcher dispatcher(&registry);
std::cout << " Max size: " << max_size << "\n";
// =========================================================================
// Run Multiple Problem Sizes
// =========================================================================
std::cout << "\nStep 2: Run Multiple Sizes\n";
print_separator();
std::cout << std::setw(12) << "M" << std::setw(12) << "N" << std::setw(12) << "K"
<< std::setw(12) << "Time(ms)" << std::setw(12) << "TFLOPS" << "\n";
print_separator();
std::vector<std::tuple<int, int, int>> all_sizes = {
{256, 256, 256},
{512, 512, 512},
{1024, 1024, 1024},
{2048, 2048, 2048},
{4096, 4096, 4096},
};
std::vector<std::tuple<int, int, int>> sizes;
for(const auto& [M, N, K] : all_sizes)
{
if(std::max({M, N, K}) <= max_size)
sizes.push_back({M, N, K});
}
using DataType = ck_tile::fp16_t;
bool all_passed = true;
for(const auto& [M, N, K] : sizes)
{
Problem problem(M, N, K);
GpuBuffer<DataType> a_dev(M * K);
GpuBuffer<DataType> b_dev(K * N);
GpuBuffer<DataType> c_dev(M * N);
std::vector<DataType> a_host(M * K, DataType(1.0f));
std::vector<DataType> b_host(K * N, DataType(1.0f));
a_dev.copy_from_host(a_host.data());
b_dev.copy_from_host(b_host.data());
c_dev.zero();
float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
double tflops = calculate_tflops(M, N, K, time_ms);
std::cout << std::setw(12) << M << std::setw(12) << N << std::setw(12) << K << std::setw(12)
<< std::fixed << std::setprecision(4) << time_ms << std::setw(12)
<< std::setprecision(2) << tflops << "\n";
// Verify
std::vector<DataType> c_host(M * N);
c_dev.copy_to_host(c_host.data());
float expected = static_cast<float>(K);
int errors = 0;
for(int i = 0; i < M * N; ++i)
{
if(std::abs(static_cast<float>(c_host[i]) - expected) > 0.01f * expected + 1.0f)
++errors;
}
if(errors > 0)
all_passed = false;
}
print_separator();
std::cout << "Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n";
print_separator();
return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,344 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 03: GEMM Benchmark & Validation
*
* Combined example demonstrating:
* 1. Benchmarking with statistics (warmup, iterations, min/max/mean/median)
* 2. Validation against CK Tile reference (CPU or GPU)
*
* Build: cd dispatcher/build && cmake .. && make gemm_03_benchmark_validation
* Usage: ./gemm_03_benchmark_validation [--size N] [--verify MODE] [--benchmark]
*
* Options:
* --size N Problem size MxNxK (default: 512)
* --verify MODE 0=none, 1=CPU ref, 2=GPU ref (default: 1)
* --benchmark Run full benchmark with statistics
* --warmup N Warmup iterations (default: 5)
* --iterations N Benchmark iterations (default: 20)
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <algorithm>
#include <numeric>
#include <cmath>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using namespace ck_tile::literals;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// =============================================================================
// KERNEL SET: High-performance kernels for benchmarking/validation
// =============================================================================
// Declared kernel set "benchmark_validation_kernels": a single explicit
// high-throughput configuration (128x128x32 tile, 2x2 waves, 32x32x16 warp
// tiles) for fp16 row/col/row GEMM on gfx942. No wildcards, so the build
// system generates exactly one kernel from this declaration.
DECL_KERNEL_SET(benchmark_validation_kernels,
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942"));
// =============================================================================
// Helper: Layout detection
// =============================================================================
template <typename Layout>
constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<Layout, ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// =============================================================================
// MAIN
// =============================================================================
int main(int argc, char* argv[])
{
    // -------------------------------------------------------------------------
    // Command line: problem size, verification mode, and benchmark knobs.
    // -------------------------------------------------------------------------
    ExampleArgs args("Example 03: GEMM Benchmark & Validation",
                     "Benchmark and/or validate GEMM output against reference");
    args.add_option("--size", "512", "Problem size MxNxK");
    args.add_option("--verify", "1", "Verification: 0=none, 1=CPU ref, 2=GPU ref");
    args.add_flag("--benchmark", "Run benchmark with statistics");
    args.add_option("--warmup", "5", "Warmup iterations");
    args.add_option("--iterations", "20", "Benchmark iterations");
    args.add_option("--rtol", "0.01", "Relative tolerance");
    args.add_option("--atol", "0.01", "Absolute tolerance");
    args.add_option("--arch", "gfx942", "GPU architecture");
    if(!args.parse(argc, argv))
        return 0;
    // --size drives a square problem: N and K mirror M.
    int M = args.get_int("--size", 512);
    int N = M;
    int K = M;
    int verify = args.get_int("--verify", 1);
    bool do_benchmark = args.has("--benchmark");
    int warmup = args.get_int("--warmup", 5);
    int iterations = args.get_int("--iterations", 20);
    float rtol = args.get_float("--rtol", 0.01f);
    float atol = args.get_float("--atol", 0.01f);
    std::string gfx_arch = args.get("--arch", "gfx942");
    print_header("Example 03: GEMM Benchmark & Validation");
    std::cout << "\nConfiguration:\n";
    std::cout << " Problem: " << M << " x " << N << " x " << K << "\n";
    std::cout << " Layout: RCR (A=row, B=col, C=row)\n";
    std::cout << " Verify: " << verify;
    if(verify == 0)
        std::cout << " (disabled)";
    else if(verify == 1)
        std::cout << " (CPU reference)";
    else if(verify == 2)
        std::cout << " (GPU reference)";
    std::cout << "\n";
    std::cout << " Benchmark: " << (do_benchmark ? "yes" : "no") << "\n";
    if(do_benchmark)
    {
        std::cout << " Warmup: " << warmup << " iterations\n";
        std::cout << " Measure: " << iterations << " iterations\n";
    }
    // =========================================================================
    // Setup Registry and Dispatcher
    // =========================================================================
    Registry registry;
    REGISTER_GENERATED_KERNELS(registry, gfx_arch);
    Dispatcher dispatcher(&registry);
    std::cout << " Kernels: " << registry.size() << " registered\n";
    print_registered_kernels(registry);
    // =========================================================================
    // Initialize data with proper tensor descriptors
    // =========================================================================
    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
    using ADataType = ck_tile::fp16_t;
    using BDataType = ck_tile::fp16_t;
    using CDataType = ck_tile::fp16_t;
    using AccDataType = float;
    // Stride of 0_uz asks CK Tile for the packed default leading dimension.
    auto stride_a = ck_tile::get_default_stride(M, K, 0_uz, is_row_major(ALayout{}));
    auto stride_b = ck_tile::get_default_stride(K, N, 0_uz, is_row_major(BLayout{}));
    auto stride_c = ck_tile::get_default_stride(M, N, 0_uz, is_row_major(CLayout{}));
    ck_tile::HostTensor<ADataType> a_m_k(
        ck_tile::host_tensor_descriptor(M, K, stride_a, is_row_major(ALayout{})));
    ck_tile::HostTensor<BDataType> b_k_n(
        ck_tile::host_tensor_descriptor(K, N, stride_b, is_row_major(BLayout{})));
    ck_tile::HostTensor<CDataType> c_m_n_dev(
        ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{})));
    ck_tile::HostTensor<CDataType> c_m_n_ref(
        ck_tile::host_tensor_descriptor(M, N, stride_c, is_row_major(CLayout{})));
    // Random inputs in [-0.5, 0.5] keep fp16 accumulation well-conditioned.
    ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
    ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
    std::cout << "\nData:\n";
    std::cout << " A: " << M << " x " << K << " (fp16, row-major)\n";
    std::cout << " B: " << K << " x " << N << " (fp16, col-major)\n";
    std::cout << " C: " << M << " x " << N << " (fp16, row-major)\n";
    // GPU memory
    ck_tile::DeviceMem a_dev(a_m_k.get_element_space_size_in_bytes());
    ck_tile::DeviceMem b_dev(b_k_n.get_element_space_size_in_bytes());
    ck_tile::DeviceMem c_dev(c_m_n_dev.get_element_space_size_in_bytes());
    a_dev.ToDevice(a_m_k.data());
    b_dev.ToDevice(b_k_n.data());
    // =========================================================================
    // Compute Reference (if needed)
    // =========================================================================
    if(verify > 0)
    {
        std::cout << "\nComputing reference...\n";
        c_m_n_ref.SetZero();
        if(verify == 1)
        {
            std::cout << " Using CPU reference (ck_tile::reference_gemm)\n";
            ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
                a_m_k, b_k_n, c_m_n_ref);
        }
        else if(verify == 2)
        {
            std::cout << " Using GPU reference (ck_tile::reference_gemm_gpu)\n";
            ck_tile::DeviceMem c_ref_dev(c_m_n_ref.get_element_space_size_in_bytes());
            c_ref_dev.SetZero();
            ck_tile::reference_gemm_gpu<ADataType,
                                        BDataType,
                                        AccDataType,
                                        CDataType,
                                        ALayout,
                                        BLayout,
                                        CLayout>(
                static_cast<ADataType*>(a_dev.GetDeviceBuffer()),
                static_cast<BDataType*>(b_dev.GetDeviceBuffer()),
                static_cast<CDataType*>(c_ref_dev.GetDeviceBuffer()),
                M,
                N,
                K,
                stride_a,
                stride_b,
                stride_c);
            // Wait for the asynchronous reference kernel before reading back.
            (void)hipDeviceSynchronize();
            c_ref_dev.FromDevice(c_m_n_ref.data());
        }
        std::cout << " Reference complete.\n";
    }
    // =========================================================================
    // Run Kernel
    // =========================================================================
    Problem problem(M, N, K);
    auto selected = dispatcher.select_kernel(problem);
    std::cout << "\nRunning kernel:\n";
    if(selected)
        std::cout << " Selected: " << selected->get_name() << "\n";
    c_dev.SetZero();
    float time_ms = 0.0f;
    std::vector<float> times;
    if(do_benchmark)
    {
        // Warmup runs are timed but discarded; they settle clocks/caches.
        std::cout << " Warming up (" << warmup << " iterations)...\n";
        for(int i = 0; i < warmup; ++i)
        {
            c_dev.SetZero();
            (void)dispatcher.run(static_cast<ADataType*>(a_dev.GetDeviceBuffer()),
                                 static_cast<BDataType*>(b_dev.GetDeviceBuffer()),
                                 static_cast<CDataType*>(c_dev.GetDeviceBuffer()),
                                 problem,
                                 nullptr);
        }
        // Benchmark
        std::cout << " Benchmarking (" << iterations << " iterations)...\n";
        times.reserve(iterations);
        for(int i = 0; i < iterations; ++i)
        {
            c_dev.SetZero();
            float t = dispatcher.run(static_cast<ADataType*>(a_dev.GetDeviceBuffer()),
                                     static_cast<BDataType*>(b_dev.GetDeviceBuffer()),
                                     static_cast<CDataType*>(c_dev.GetDeviceBuffer()),
                                     problem,
                                     nullptr);
            times.push_back(t);
        }
        // Best (minimum) time is the representative figure in benchmark mode.
        time_ms = *std::min_element(times.begin(), times.end());
    }
    else
    {
        // Single run
        time_ms = dispatcher.run(static_cast<ADataType*>(a_dev.GetDeviceBuffer()),
                                 static_cast<BDataType*>(b_dev.GetDeviceBuffer()),
                                 static_cast<CDataType*>(c_dev.GetDeviceBuffer()),
                                 problem,
                                 nullptr);
    }
    c_dev.FromDevice(c_m_n_dev.data());
    // =========================================================================
    // Results
    // =========================================================================
    // GEMM does 2*M*N*K flops; time_ms is min time (benchmark) or the single run.
    double flops = 2.0 * M * N * K;
    double tflops = flops / (time_ms * 1e9);
    print_separator();
    std::cout << "Performance:\n";
    print_separator();
    if(do_benchmark && !times.empty())
    {
        // Sorting gives min/max/median directly; mean computed separately.
        std::sort(times.begin(), times.end());
        float min_t = times.front();
        float max_t = times.back();
        float median_t = times[times.size() / 2];
        float mean_t = std::accumulate(times.begin(), times.end(), 0.0f) / times.size();
        std::cout << std::fixed << std::setprecision(4);
        std::cout << " Min: " << min_t << " ms (" << std::setprecision(2)
                  << (flops / (min_t * 1e9)) << " TFLOPS)\n";
        std::cout << std::setprecision(4);
        std::cout << " Max: " << max_t << " ms\n";
        std::cout << " Mean: " << mean_t << " ms (" << std::setprecision(2)
                  << (flops / (mean_t * 1e9)) << " TFLOPS)\n";
        std::cout << std::setprecision(4);
        std::cout << " Median: " << median_t << " ms (" << std::setprecision(2)
                  << (flops / (median_t * 1e9)) << " TFLOPS)\n";
    }
    else
    {
        std::cout << std::fixed << std::setprecision(4);
        std::cout << " Time: " << time_ms << " ms\n";
        std::cout << std::setprecision(2);
        std::cout << " TFLOPS: " << tflops << "\n";
    }
    // =========================================================================
    // Validation
    // =========================================================================
    bool pass = true;
    if(verify > 0)
    {
        print_separator();
        std::cout << "Validation:\n";
        print_separator();
        std::cout << " Tolerance: rtol=" << rtol << ", atol=" << atol << "\n";
        pass = ck_tile::check_err(c_m_n_dev, c_m_n_ref, "Validation Error!", rtol, atol);
        // Report worst-case absolute/relative error in addition to pass/fail.
        // NOTE(review): this iterates the full element space, not M*N; for the
        // packed default strides above these coincide — confirm if strides
        // ever become non-packed.
        float max_abs_diff = 0.0f;
        float max_rel_diff = 0.0f;
        for(size_t i = 0; i < c_m_n_dev.get_element_space_size(); ++i)
        {
            float dev_val = static_cast<float>(c_m_n_dev.mData[i]);
            float ref_val = static_cast<float>(c_m_n_ref.mData[i]);
            float abs_diff = std::abs(dev_val - ref_val);
            float rel_diff = (ref_val != 0.0f) ? abs_diff / std::abs(ref_val) : abs_diff;
            max_abs_diff = std::max(max_abs_diff, abs_diff);
            max_rel_diff = std::max(max_rel_diff, rel_diff);
        }
        std::cout << " Max abs diff: " << max_abs_diff << "\n";
        std::cout << " Max rel diff: " << max_rel_diff << "\n";
    }
    // =========================================================================
    // Summary
    // =========================================================================
    print_separator();
    std::cout << "Result: " << (pass ? "PASS" : "FAIL") << "\n";
    print_separator();
    return pass ? 0 : 1;
}

View File

@@ -0,0 +1,168 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 04: Custom Heuristics
*
* Demonstrates custom kernel selection heuristics for different workloads.
*
* Build: cd dispatcher/build && cmake .. && make gemm_04_heuristics
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include <algorithm>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// =============================================================================
// KERNEL SET: Multiple tile sizes for heuristic-based selection
// =============================================================================
// Declared kernel set "heuristics_kernels": two explicit tile sizes so the
// custom heuristic below has a small and a large kernel to rank between.
DECL_KERNEL_SET(heuristics_kernels,
    // Small tile - low latency
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    // Medium tile - balanced
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 64)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942"));
// =============================================================================
// Custom Heuristic
// =============================================================================
/// Rank kernels by output size: small problems prefer the small (low-latency)
/// tile first, large problems prefer the big (high-throughput) tile first.
///
/// @param problem  GEMM problem dimensions (reads M and N only).
/// @return kernel names in preference order, most preferred first.
///         NOTE(review): these strings must match the registered kernel names
///         produced from DECL_KERNEL_SET — verify against the generated header.
std::vector<std::string> size_based_heuristic(const Problem& problem)
{
    // Widen BEFORE multiplying: if M and N are 32-bit ints, problem.M *
    // problem.N would be computed in int and could overflow for large
    // problems before being assigned to int64_t.
    const int64_t total_elements =
        static_cast<int64_t>(problem.M) * static_cast<int64_t>(problem.N);
    if(total_elements < 100000)
    {
        return {"gemm_64x64", "gemm_128x128"};
    }
    return {"gemm_128x128", "gemm_64x64"};
}
// =============================================================================
// MAIN
// =============================================================================
int main(int argc, char* argv[])
{
ExampleArgs args("Example 04: Custom Heuristics",
"Demonstrates custom kernel selection heuristics");
args.add_option("--arch", "gfx942", "GPU architecture");
if(!args.parse(argc, argv))
return 0;
print_header("Example 04: Custom Heuristics");
std::string gfx_arch = args.get("--arch", "gfx942");
// =========================================================================
// Setup Registry and Dispatcher
// =========================================================================
Registry registry;
REGISTER_GENERATED_KERNELS(registry, gfx_arch);
Dispatcher dispatcher(&registry);
dispatcher.set_strategy(Dispatcher::SelectionStrategy::Heuristic);
dispatcher.set_heuristic(size_based_heuristic);
std::cout << "\nSetup:\n";
std::cout << " Registry: " << registry.size() << " kernel(s)\n";
std::cout << " Strategy: Heuristic (size-based)\n";
// =========================================================================
// Test Different Problem Sizes
// =========================================================================
std::cout << "\nTesting heuristic selection:\n";
print_separator();
using DataType = ck_tile::fp16_t;
std::vector<std::tuple<int, int, int>> sizes = {
{128, 128, 64},
{512, 512, 256},
{2048, 2048, 1024},
};
bool all_passed = true;
for(const auto& [M, N, K] : sizes)
{
Problem problem(M, N, K);
auto selected = dispatcher.select_kernel(problem);
std::cout << "Problem " << M << "x" << N << "x" << K << ":\n";
if(selected)
{
std::cout << " Selected: " << selected->get_name() << "\n";
}
GpuBuffer<DataType> a_dev(M * K);
GpuBuffer<DataType> b_dev(K * N);
GpuBuffer<DataType> c_dev(M * N);
std::vector<DataType> a_host(M * K, DataType(1.0f));
std::vector<DataType> b_host(K * N, DataType(1.0f));
a_dev.copy_from_host(a_host.data());
b_dev.copy_from_host(b_host.data());
c_dev.zero();
float time_ms = dispatcher.run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
double tflops = calculate_tflops(M, N, K, time_ms);
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Verify
std::vector<DataType> c_host(M * N);
c_dev.copy_to_host(c_host.data());
float expected = static_cast<float>(K);
int errors = 0;
for(int i = 0; i < M * N; ++i)
{
float actual = static_cast<float>(c_host[i]);
if(std::abs(actual - expected) > 0.01f * expected + 1.0f)
++errors;
}
bool pass = (errors == 0);
std::cout << " Verify: " << (pass ? "PASS" : "FAIL") << "\n";
if(!pass)
all_passed = false;
print_separator();
}
std::cout << "Overall: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n";
return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,127 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 05: JSON Export
*
* Demonstrates exporting registry information to JSON format.
*
* Build: cd dispatcher/build && cmake .. && make gemm_05_json_export
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <fstream>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// =============================================================================
// KERNEL SET: Multiple kernels for JSON export demo
// =============================================================================
// Declared kernel set "json_export_kernels": two explicit configurations
// (64x64x32 and 128x128x64 tiles) so the exported JSON has more than one
// kernel entry to show.
DECL_KERNEL_SET(json_export_kernels,
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 64)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942"));
// =============================================================================
// MAIN
// =============================================================================
int main(int argc, char* argv[])
{
    // -------------------------------------------------------------------------
    // Command line: output path, target arch, optional kernel-set listing.
    // -------------------------------------------------------------------------
    ExampleArgs args("Example 05: JSON Export", "Export registry information to JSON format");
    args.add_option("--output", "registry.json", "Output JSON file path");
    args.add_option("--arch", "gfx942", "GPU architecture");
    args.add_flag("--list", "List all kernel sets");
    if(!args.parse(argc, argv))
        return 0;
    print_header("Example 05: JSON Export");
    const std::string target_arch = args.get("--arch", "gfx942");
    if(args.has("--list"))
    {
        std::cout << "\nDeclared Kernel Sets:\n";
        KernelSetRegistry::instance().print();
        return 0;
    }
    const std::string out_path = args.get("--output", "registry.json");
    // -------------------------------------------------------------------------
    // Build a registry populated from the generated kernel declarations.
    // -------------------------------------------------------------------------
    std::cout << "\nSetting up registry...\n";
    Registry registry;
    registry.set_name("json_export_registry");
    REGISTER_GENERATED_KERNELS(registry, target_arch);
    std::cout << " Registry: " << registry.get_name() << "\n";
    std::cout << " Kernels: " << registry.size() << "\n";
    // -------------------------------------------------------------------------
    // Serialize the registry, preview the first 500 characters, write to disk.
    // -------------------------------------------------------------------------
    std::cout << "\nExporting to JSON...\n";
    const std::string payload = registry.export_json(true);
    std::cout << "\nJSON Preview (first 500 chars):\n";
    print_separator();
    std::cout << payload.substr(0, std::min(size_t(500), payload.size()));
    if(payload.size() > 500)
        std::cout << "\n...";
    std::cout << "\n";
    print_separator();
    std::ofstream out(out_path);
    if(!out.is_open())
    {
        // Early exit on I/O failure; nothing else below is meaningful then.
        std::cerr << "Failed to write to: " << out_path << "\n";
        return 1;
    }
    out << payload;
    out.close();
    std::cout << "\nExported to: " << out_path << "\n";
    std::cout << "File size: " << payload.size() << " bytes\n";
    // -------------------------------------------------------------------------
    // Finally, dump the kernel-set declarations made via DECL_KERNEL_SET.
    // -------------------------------------------------------------------------
    std::cout << "\nKernel Set Declarations:\n";
    print_separator();
    KernelSetRegistry::instance().print();
    print_separator();
    return 0;
}

View File

@@ -0,0 +1,294 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
/**
* Example 06: Multiple Registries and Multiple Kernel Sets
*
* Demonstrates:
* - Multiple DECL_KERNEL_SET declarations (each with multiple kernels)
* - Separate Registry instances for different workload types
* - Independent Dispatchers that select from their respective registries
*
* Registration patterns:
* - REGISTER_GENERATED_KERNELS(registry, arch) -> all kernels to one registry
* - REGISTER_KERNEL_SET("set_name", registry, arch) -> specific set by name
* - generated::get_kernel_set_names() -> list available set names
*
* Build: cd dispatcher/build && cmake .. && make gemm_06_multi_registry
* Usage: ./gemm_06_multi_registry [--list] [--help]
*/
#include <hip/hip_runtime.h>
#include <iostream>
#include <iomanip>
#include <vector>
#include "ck_tile/dispatcher.hpp"
#include "ck_tile/dispatcher/kernel_decl.hpp"
#include "ck_tile/dispatcher/example_args.hpp"
using namespace ck_tile::dispatcher;
using namespace ck_tile::dispatcher::utils;
using Signature = decl::Signature;
using Algorithm = decl::Algorithm;
// =============================================================================
// KERNEL SETS: Multiple sets with multiple kernels each
// =============================================================================
// Compute-bound kernel set: Large tiles for high arithmetic intensity
// Max tile with 32x32 warp is 128x128 (16 warps = 1024 threads)
// Three independently-named kernel sets, registered later into separate (and
// merged) registries by name via REGISTER_KERNEL_SET.
DECL_KERNEL_SET(compute_bound_set,
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 64) // Large tile, max for 32x32 warp
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 128, 32) // Same tile, different K for variety
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942"));
// Memory-bound kernel set: Smaller tiles for better cache efficiency
DECL_KERNEL_SET(memory_bound_set,
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942")
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(128, 64, 32)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942"));
// Latency-optimized: Minimal overhead tiles
DECL_KERNEL_SET(latency_set,
    .add(Signature().dtype("fp16").layout("rcr"),
         Algorithm()
             .tile(64, 64, 64)
             .wave(2, 2, 1)
             .warp(32, 32, 16)
             .pipeline("compv3")
             .scheduler("intrawave")
             .epilogue("cshuffle"),
         "gfx942"));
// =============================================================================
// MAIN
// =============================================================================
int main(int argc, char* argv[])
{
ExampleArgs args("Example 06: Multiple Registries",
"Separate registries for different workload types");
args.add_flag("--list", "List all declared kernel sets");
args.add_option("--arch", "gfx942", "GPU architecture");
if(!args.parse(argc, argv))
return 0;
print_header("Example 06: Multiple Registries & Kernel Sets");
std::string gfx_arch = args.get("--arch", "gfx942");
// =========================================================================
// Step 1: Show declared kernel sets (from DECL_KERNEL_SET macros)
// =========================================================================
std::cout << "\nStep 1: Declared Kernel Sets\n";
std::cout << "-----------------------------\n";
KernelSetRegistry::instance().print();
if(args.has("--list"))
{
// Print detailed info
for(const auto& name : KernelSetRegistry::instance().names())
{
const auto& set = KernelSetRegistry::instance().get(name);
std::cout << "\n " << name << ":\n";
for(const auto& decl : set.declarations())
{
std::cout << " - " << decl.name() << " (tile=" << decl.algorithm.tile_m_ << "x"
<< decl.algorithm.tile_n_ << "x" << decl.algorithm.tile_k_ << ")\n";
}
}
return 0;
}
// =========================================================================
// Step 2: Create registries and demonstrate MERGING
// =========================================================================
std::cout << "\nStep 2: Create and Merge Registries\n";
std::cout << "------------------------------------\n";
// Create individual registries first
Registry compute_registry;
Registry latency_registry;
Registry memory_registry;
compute_registry.set_name("compute_bound");
latency_registry.set_name("latency_optimized");
memory_registry.set_name("memory_bound");
// Register kernels to individual registries using set names (no hardcoding)
REGISTER_KERNEL_SET("compute_bound_set", compute_registry, gfx_arch);
REGISTER_KERNEL_SET("latency_set", latency_registry, gfx_arch);
REGISTER_KERNEL_SET("memory_bound_set", memory_registry, gfx_arch);
std::cout << " Individual registries:\n";
std::cout << " compute_bound: " << compute_registry.size() << " kernel(s)\n";
std::cout << " latency_optimized: " << latency_registry.size() << " kernel(s)\n";
std::cout << " memory_bound: " << memory_registry.size() << " kernel(s)\n";
// MERGE compute + latency into a combined registry
Registry combined_registry;
combined_registry.set_name("compute_latency_combined");
// Register both sets into combined registry
REGISTER_KERNEL_SET("compute_bound_set", combined_registry, gfx_arch);
REGISTER_KERNEL_SET("latency_set", combined_registry, gfx_arch);
std::cout << "\n After merging compute + latency:\n";
std::cout << " combined: " << combined_registry.size() << " kernel(s)\n";
std::cout << " memory (separate): " << memory_registry.size() << " kernel(s)\n";
// =========================================================================
// Step 3: Create dispatchers - one merged, one separate
// =========================================================================
std::cout << "\nStep 3: Create Dispatchers\n";
std::cout << "--------------------------\n";
Dispatcher combined_dispatcher(&combined_registry); // compute + latency merged
Dispatcher memory_dispatcher(&memory_registry); // memory separate
std::cout << " combined_dispatcher: compute + latency kernels (" << combined_registry.size()
<< " kernels)\n";
std::cout << " memory_dispatcher: memory-bound kernels (" << memory_registry.size()
<< " kernels)\n";
// =========================================================================
// Step 4: Run with different dispatchers
// =========================================================================
std::cout << "\nStep 4: Run Workloads\n";
print_separator();
using DataType = ck_tile::fp16_t;
struct WorkloadTest
{
const char* name;
Dispatcher* dispatcher;
int M, N, K;
};
std::vector<WorkloadTest> tests = {
{"Compute-bound (combined)", &combined_dispatcher, 4096, 4096, 4096},
{"Memory-bound (separate)", &memory_dispatcher, 1024, 1024, 1024},
{"Latency-opt (combined)", &combined_dispatcher, 512, 512, 512},
};
bool all_passed = true;
for(const auto& test : tests)
{
Problem problem(test.M, test.N, test.K);
// Allocate and initialize
GpuBuffer<DataType> a_dev(test.M * test.K);
GpuBuffer<DataType> b_dev(test.K * test.N);
GpuBuffer<DataType> c_dev(test.M * test.N);
std::vector<DataType> a_host(test.M * test.K, DataType(1.0f));
std::vector<DataType> b_host(test.K * test.N, DataType(1.0f));
a_dev.copy_from_host(a_host.data());
b_dev.copy_from_host(b_host.data());
c_dev.zero();
// Select kernel and run
auto selected = test.dispatcher->select_kernel(problem);
float time_ms =
test.dispatcher->run(a_dev.get(), b_dev.get(), c_dev.get(), problem, nullptr);
double tflops = calculate_tflops(test.M, test.N, test.K, time_ms);
std::cout << test.name << " (" << test.M << "x" << test.N << "x" << test.K << "):\n";
if(selected)
std::cout << " Selected: " << selected->get_name() << "\n";
std::cout << " Time: " << std::fixed << std::setprecision(4) << time_ms << " ms\n";
std::cout << " TFLOPS: " << std::setprecision(2) << tflops << "\n";
// Verify ALL elements
std::vector<DataType> c_host(test.M * test.N);
c_dev.copy_to_host(c_host.data());
const float expected = static_cast<float>(test.K);
int num_errors = 0;
float max_error = 0.0f;
for(int i = 0; i < test.M * test.N; ++i)
{
float actual = static_cast<float>(c_host[i]);
float error = std::abs(actual - expected);
max_error = std::max(max_error, error);
// Allow 1% relative tolerance for FP16 accumulation
if(error > 0.01f * expected + 1.0f)
++num_errors;
}
bool test_passed = (num_errors == 0);
std::cout << " Verify: " << (test.M * test.N) << " elements, errors=" << num_errors
<< "\n";
std::cout << " Status: " << (test_passed ? "PASS" : "FAIL") << "\n\n";
if(!test_passed)
all_passed = false;
}
// =========================================================================
// Summary
// =========================================================================
print_separator();
std::cout << "Multi-Registry Pattern Summary:\n";
print_separator();
std::cout << R"(
// 1. Declare multiple kernel sets
DECL_KERNEL_SET(compute_bound_set, .add(...));
DECL_KERNEL_SET(memory_bound_set, .add(...));
DECL_KERNEL_SET(latency_set, .add(...));
// 2. Create registries and register by set NAME (no hardcoding!)
Registry combined_reg, memory_reg;
REGISTER_KERNEL_SET("compute_bound_set", combined_reg, arch); // Add compute
REGISTER_KERNEL_SET("latency_set", combined_reg, arch); // Merge latency
REGISTER_KERNEL_SET("memory_bound_set", memory_reg, arch); // Separate
// 3. Create dispatchers from merged/separate registries
Dispatcher combined_disp(&combined_reg); // Has both compute + latency
Dispatcher memory_disp(&memory_reg); // Has only memory-bound
// 4. Choose dispatcher based on workload
if (problem.is_memory_bound())
memory_disp.run(...);
else
combined_disp.run(...); // Handles both compute & latency workloads
)";
print_separator();
std::cout << "Overall Status: " << (all_passed ? "ALL PASSED" : "SOME FAILED") << "\n";
return all_passed ? 0 : 1;
}

View File

@@ -0,0 +1,229 @@
# GEMM C++ Examples
CK Tile Dispatcher C++ examples for GEMM (General Matrix Multiplication) operations.
> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md)
## Quick Start
### Build and Run
```bash
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DBUILD_DISPATCHER_EXAMPLES=ON
# Build (kernels generated automatically by CMake)
make -j$(nproc)
# Run examples
cd examples
./gemm_01_basic
./gemm_03_benchmark_validation
./gemm_04_heuristics
```
## Examples
| Example | Description | Complexity |
|---------|-------------|------------|
| [01_basic_gemm.cpp](01_basic_gemm.cpp) | Basic GEMM with declarative API, autofill, autocorrect | ★☆☆☆☆ |
| [02_multi_size.cpp](02_multi_size.cpp) | Wildcard expansion for multiple configurations | ★★☆☆☆ |
| [03_benchmark_validation.cpp](03_benchmark_validation.cpp) | Performance benchmarking with CPU reference validation | ★★☆☆☆ |
| [04_heuristics.cpp](04_heuristics.cpp) | Heuristic-based kernel selection | ★★★☆☆ |
| [05_json_export.cpp](05_json_export.cpp) | Registry JSON export for external tools | ★★☆☆☆ |
| [06_multi_registry.cpp](06_multi_registry.cpp) | Multiple registries with named kernel sets | ★★★☆☆ |
## Example Details
### 01_basic_gemm.cpp - Basic GEMM
Demonstrates the declarative kernel API with three patterns:
1. **Autofill Pattern** - Minimal specification, defaults filled automatically
2. **Autocorrect Pattern** - Invalid parameters corrected at build time
3. **Full Specification Pattern** - Complete kernel configuration
```cpp
DECL_KERNEL_SET(basic_kernels,
// Pattern 1: Autofill - minimal specification
.add(
Signature().dtype("fp16").layout("rcr"),
Algorithm(), // Defaults filled by autofill
"gfx942"
)
// Pattern 2: Full specification
.add(
Signature().dtype("fp16").layout("rcr"),
Algorithm().tile(256, 256, 32).wave(2, 2, 1).warp(32, 32, 16)
.pipeline("compv4").scheduler("intrawave"),
"gfx942"
)
);
```
**Features:**
- Uses generic `REGISTER_GENERATED_KERNELS` macro
- `print_registered_kernels()` utility for debugging
- Demonstrates autofill messages during build
### 02_multi_size.cpp - Wildcard Expansion
Demonstrates automatic generation of multiple kernel configurations:
```cpp
DECL_KERNEL_SET(multi_kernels,
.add(
Signature().dtype("fp16").layout("rcr"),
Algorithm().tile(ANY_INT, ANY_INT, 32) // Wildcard tile M and N (see wildcard values below)
.wave(2, 2, 1)
.warp(32, 32, 16)
.pipeline("compv4")
.scheduler("intrawave"),
"gfx942"
)
);
```
**Wildcard Values:**
- `*`, `-1`, or `ANY_INT` expand to all valid configurations
- Architecture filter prunes invalid combinations automatically
- Example generates 5 valid kernels after arch filtering (from 7 expansions)
### 03_benchmark_validation.cpp - Benchmark + Validation
Consolidated example combining performance benchmarking with correctness validation:
```bash
# Benchmark only
./gemm_03_benchmark_validation --warmup 10 --iterations 100
# With CPU validation
./gemm_03_benchmark_validation --verify 1 --rtol 1e-3 --atol 1e-3
# With GPU reference validation (faster for large matrices)
./gemm_03_benchmark_validation --verify 2
```
**Features:**
- Warmup iterations (discarded from timing)
- Benchmark iterations with statistics (min/max/mean/median)
- CPU reference validation using `ck_tile::reference_gemm`
- GPU reference validation using `ck_tile::reference_gemm_gpu`
- Configurable tolerances
### 04_heuristics.cpp - Heuristic Selection
Demonstrates custom kernel selection based on problem characteristics:
```cpp
// Problem size analysis
auto heuristic = [](const Problem& p) -> std::optional<KernelKey> {
if (p.M() * p.N() < 256 * 256) {
return small_kernel_key; // Memory-bound heuristic
} else {
return large_kernel_key; // Compute-bound heuristic
}
};
dispatcher.set_heuristic(heuristic);
```
**Features:**
- Problem size analysis (small vs large matrices)
- Compute-bound vs memory-bound selection
- Custom heuristic function registration
### 05_json_export.cpp - JSON Export
Exports registry information to JSON for external tool integration:
```cpp
auto json = registry.to_json();
std::ofstream file("kernels.json");
file << json;
```
**Use Cases:**
- Kernel metadata serialization
- External analysis tools
- Configuration management
### 06_multi_registry.cpp - Multiple Registries
Demonstrates using multiple registries with named kernel sets:
```cpp
// Define separate kernel sets
DECL_KERNEL_SET(compute_optimized, ...);
DECL_KERNEL_SET(latency_optimized, ...);
// Register to specific registries by set name (string) and target arch
Registry compute_registry, latency_registry;
REGISTER_KERNEL_SET("compute_optimized", compute_registry, arch);
REGISTER_KERNEL_SET("latency_optimized", latency_registry, arch);
// Use the appropriate registry based on workload
Dispatcher compute_dispatcher(&compute_registry);
Dispatcher latency_dispatcher(&latency_registry);
```
**Features:**
- Named kernel set registration with `REGISTER_KERNEL_SET` macro
- Separate registries for different optimization goals
- Dynamic kernel set selection by name
## Benchmark Parameters (stream_config)
CK Tile uses `stream_config` for benchmark control:
```cpp
ck_tile::stream_config cfg{
nullptr, // stream_id - HIP stream (nullptr = default)
true, // time_kernel - Enable timing
1, // log_level - Verbosity (0=quiet, 1=normal)
5, // cold_niters - Warmup iterations
20, // nrepeat - Benchmark iterations
true, // is_gpu_timer - Use GPU events vs CPU chrono
false, // flush_cache - Flush L2 cache between iterations
1 // rotating_count - Rotating buffers for cache simulation
};
```
| Parameter | CLI Option | Default | Description |
|-----------|------------|---------|-------------|
| `cold_niters_` | `--warmup` | 5 | Warmup iterations |
| `nrepeat_` | `--iterations` | 100 | Benchmark iterations |
| `flush_cache_` | - | false | Flush L2 cache |
| `rotating_count_` | - | 1 | Rotating buffers |
| `is_gpu_timer_` | - | true | GPU timer vs CPU |
## Declarative Kernel Pattern
All examples use the declarative `DECL_KERNEL_SET` macro:
```cpp
DECL_KERNEL_SET(my_kernels,
.add(
Signature() // WHAT: operation signature
.dtype("fp16") // Data type
.layout("rcr"), // Matrix layouts (A=row, B=col, C=row)
Algorithm() // HOW: implementation details
.tile(256, 256, 32) // Tile sizes (M, N, K)
.wave(2, 2, 1) // Wave configuration
.warp(32, 32, 16) // Warp tile sizes
.pipeline("compv4") // Pipeline type
.scheduler("intrawave"), // Scheduler type
"gfx942" // WHERE: target architecture
)
);
```
**Key Macros:**
- `DECL_KERNEL_SET(name, ...)` - Declare a kernel set
- `REGISTER_GENERATED_KERNELS` - Register all kernels from this example
- `REGISTER_KERNEL_SET(name, registry, arch)` - Register a named kernel set to a registry for a target architecture
## Related Documentation
- [Python GEMM Examples](../python/README.md)
- [Convolution Examples](../../conv/cpp/README.md)
- [Main Dispatcher README](../../../README.md)

View File

@@ -0,0 +1,331 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 01: Basic GEMM with Multiple Kernels
Demonstrates:
1. Declaring multiple kernel configurations
2. Printing all registered kernels
3. Running each kernel and validating output
4. Comparing performance across kernels
Complexity: ★★☆☆☆
Usage:
python3 01_basic_gemm.py
python3 01_basic_gemm.py --help
python3 01_basic_gemm.py --dtype bf16
python3 01_basic_gemm.py --size 2048
"""
import sys
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
@dataclass
class KernelSpec:
    """Specification for a kernel configuration"""

    # Short identifier used for table rows and per-kernel registry names.
    name: str
    # GEMM tile sizes along M, N, and K.
    tile_m: int
    tile_n: int
    tile_k: int
    # Pipeline variant ("compv3" or "compv4" in this example set).
    pipeline: str = "compv3"
    # Scheduler strategy ("intrawave" or "interwave" in this example set).
    scheduler: str = "intrawave"
# Define multiple kernel configurations to test (48 kernels).
# Grouped by tile shape (small/medium/rectangular/large), pipeline
# variant (compv3/compv4), scheduler, and tile_k depth.
KERNEL_SPECS = [
    # Small tiles - compv3
    KernelSpec("small_64x64_k32", 64, 64, 32, "compv3"),
    KernelSpec("small_64x64_k64", 64, 64, 64, "compv3"),
    # Small tiles - compv4
    KernelSpec("small_64x64_v4_k32", 64, 64, 32, "compv4"),
    KernelSpec("small_64x64_v4_k64", 64, 64, 64, "compv4"),
    # Medium tiles - compv3
    KernelSpec("med_128x128_k32", 128, 128, 32, "compv3"),
    KernelSpec("med_128x128_k64", 128, 128, 64, "compv3"),
    KernelSpec("med_128x128_k128", 128, 128, 128, "compv3"),
    # Medium tiles - compv4
    KernelSpec("med_128x128_v4_k32", 128, 128, 32, "compv4"),
    KernelSpec("med_128x128_v4_k64", 128, 128, 64, "compv4"),
    KernelSpec("med_128x128_v4_k128", 128, 128, 128, "compv4"),
    # Rectangular tiles - compv3
    KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3"),
    KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3"),
    KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3"),
    KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3"),
    # Rectangular tiles - compv4
    KernelSpec("rect_64x128_v4_k32", 64, 128, 32, "compv4"),
    KernelSpec("rect_64x128_v4_k64", 64, 128, 64, "compv4"),
    KernelSpec("rect_128x64_v4_k32", 128, 64, 32, "compv4"),
    KernelSpec("rect_128x64_v4_k64", 128, 64, 64, "compv4"),
    # Large tiles - compv3
    KernelSpec("large_256x128_k32", 256, 128, 32, "compv3"),
    KernelSpec("large_256x128_k64", 256, 128, 64, "compv3"),
    KernelSpec("large_128x256_k32", 128, 256, 32, "compv3"),
    KernelSpec("large_128x256_k64", 128, 256, 64, "compv3"),
    KernelSpec("large_256x256_k32", 256, 256, 32, "compv3"),
    KernelSpec("large_256x256_k64", 256, 256, 64, "compv3"),
    # Large tiles - compv4
    KernelSpec("large_256x128_v4_k32", 256, 128, 32, "compv4"),
    KernelSpec("large_256x128_v4_k64", 256, 128, 64, "compv4"),
    KernelSpec("large_128x256_v4_k32", 128, 256, 32, "compv4"),
    KernelSpec("large_128x256_v4_k64", 128, 256, 64, "compv4"),
    KernelSpec("large_256x256_v4_k32", 256, 256, 32, "compv4"),
    KernelSpec("large_256x256_v4_k64", 256, 256, 64, "compv4"),
    # Interwave scheduler variants
    KernelSpec("int_64x64_k32", 64, 64, 32, "compv3", "interwave"),
    KernelSpec("int_128x128_k32", 128, 128, 32, "compv3", "interwave"),
    KernelSpec("int_128x128_k64", 128, 128, 64, "compv3", "interwave"),
    KernelSpec("int_256x128_k32", 256, 128, 32, "compv3", "interwave"),
    # More tile_k variations - compv3
    KernelSpec("med_128x128_k16", 128, 128, 16, "compv3"),
    KernelSpec("rect_64x128_k16", 64, 128, 16, "compv3"),
    KernelSpec("rect_128x64_k16", 128, 64, 16, "compv3"),
    # More tile_k variations - compv4
    KernelSpec("med_128x128_v4_k16", 128, 128, 16, "compv4"),
    KernelSpec("rect_64x128_v4_k16", 64, 128, 16, "compv4"),
    KernelSpec("rect_128x64_v4_k16", 128, 64, 16, "compv4"),
    # Additional rectangular
    KernelSpec("rect_32x64_k32", 32, 64, 32, "compv3"),
    KernelSpec("rect_64x32_k32", 64, 32, 32, "compv3"),
    KernelSpec("rect_32x128_k32", 32, 128, 32, "compv3"),
    KernelSpec("rect_128x32_k32", 128, 32, 32, "compv3"),
    # Additional compv4 variants
    KernelSpec("rect_32x64_v4_k32", 32, 64, 32, "compv4"),
    KernelSpec("rect_64x32_v4_k32", 64, 32, 32, "compv4"),
    KernelSpec("rect_32x128_v4_k32", 32, 128, 32, "compv4"),
    KernelSpec("rect_128x32_v4_k32", 128, 32, 32, "compv4"),
]
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Build a fully-populated KernelConfig from a KernelSpec.

    All three operand dtypes use *dtype* with fp32 accumulation and a
    row/col/row (rcr) layout. Warp tiles are sized from the macro tile:
    16x16 warps for tile_m <= 64, otherwise 32x32.
    """
    # Small macro tiles pair with small warp tiles; warp_k is fixed at 16.
    warp_dim = 16 if spec.tile_m <= 64 else 32
    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=warp_dim,
        warp_n=warp_dim,
        warp_k=16,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
def print_kernel_table(specs: List[KernelSpec], dtype: str):
    """Render the declared kernel configurations as an aligned console table."""
    rule = "=" * 70
    divider = " " + "-" * 68
    print("\n" + rule)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print(rule)
    print(f"\n {'#':<3} {'Name':<18} {'Tile':<14} {'Pipeline':<10} {'Scheduler':<12}")
    print(divider)
    for idx, entry in enumerate(specs, 1):
        shape = f"{entry.tile_m}x{entry.tile_n}x{entry.tile_k}"
        print(
            f" {idx:<3} {entry.name:<18} {shape:<14} {entry.pipeline:<10} {entry.scheduler:<12}"
        )
    print(divider)
    print(f" Data type: {dtype}")
def main():
    """Declare a suite of GEMM kernels, run each, and validate against NumPy.

    Prints a configuration table, executes every (selected) kernel on a
    square problem of --size, checks the output against an fp32 NumPy
    reference, and prints a summary. Returns 0 only if every kernel passed.
    """
    parser = argparse.ArgumentParser(
        description="Basic GEMM Example with Multiple Kernels",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 01_basic_gemm.py                  # Default FP16 with 4 kernels
  python3 01_basic_gemm.py --dtype bf16     # BF16 mode
  python3 01_basic_gemm.py --size 2048      # Larger problem size
  python3 01_basic_gemm.py --num-kernels 2  # Test only 2 kernels
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=512,
        help="Problem size MxNxK (default: 512)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    args = parser.parse_args()
    # Reset shared dispatcher state so this example starts clean.
    reset_for_example()
    print("=" * 70)
    print("Example 01: Basic GEMM with Multiple Kernels")
    print("=" * 70)
    # Select kernels to test (0 means run the whole KERNEL_SPECS list)
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS
    # =========================================================================
    # Step 1: Print all kernel configurations
    # =========================================================================
    print_kernel_table(specs, args.dtype)
    # =========================================================================
    # Step 2: Setup and test each kernel
    # =========================================================================
    print("\n" + "=" * 70)
    print(" RUNNING KERNELS")
    print("=" * 70)
    # NOTE(review): bf16 inputs are materialized as np.float16 host arrays —
    # confirm the dispatcher performs the bf16 conversion internally.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    M, N, K = args.size, args.size, args.size
    # Each entry: (name, passed, time_ms, tflops, max_err)
    results = []
    print(f"\n Problem size: {M}x{N}x{K}\n")
    print(
        f" {'#':<3} {'Name':<18} {'Tile':<14} {'Time (ms)':>10} {'TFLOPS':>10} {'Max Err':>10} {'Status':<8}"
    )
    print(" " + "-" * 78)
    for i, spec in enumerate(specs, 1):
        # Create unique test data per kernel (seed varies with kernel index)
        np.random.seed(42 + i * 1000)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
        # Create config and setup dispatcher (one registry per kernel)
        config = create_kernel_config(spec, args.dtype, args.arch)
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"kernel_{spec.name}",
            verbose=False,
            auto_rebuild=True,
        )
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        if not setup.success:
            # Setup failure is recorded like a test failure in the summary.
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue
        dispatcher = setup.dispatcher
        # Check if size is supported.
        # NOTE(review): skipped kernels are also tallied as failures below.
        if not dispatcher.is_supported(M, N, K):
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'SKIP':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue
        # Run GEMM
        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            print(
                f" {i:<3} {spec.name:<18} {tile:<14} {'N/A':>10} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((spec.name, False, 0, 0, 0))
            cleanup_gemm()
            continue
        # Validate against NumPy reference (computed in fp32, cast back)
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        max_err = np.max(np.abs(result.output - C_ref))
        # Check if within tolerance (absolute threshold of 1e-2)
        passed = max_err < 1e-2
        status = "PASS" if passed else "FAIL"
        print(
            f" {i:<3} {spec.name:<18} {tile:<14} {result.time_ms:>10.4f} {result.tflops:>10.2f} {max_err:>10.2e} {status:<8}"
        )
        results.append((spec.name, passed, result.time_ms, result.tflops, max_err))
        # Tear down per-kernel dispatcher state before the next config.
        cleanup_gemm()
    # =========================================================================
    # Step 3: Summary
    # =========================================================================
    print("\n" + "=" * 70)
    print(" SUMMARY")
    print("=" * 70)
    passed = sum(1 for r in results if r[1])
    failed = len(results) - passed
    print(f"\n Results: {passed}/{len(results)} kernels passed")
    print(f" Problem: {M}x{N}x{K}, dtype={args.dtype}")
    if results:
        # Report the fastest kernel among those that passed (r[3] = TFLOPS).
        valid_results = [r for r in results if r[1]]
        if valid_results:
            best = max(valid_results, key=lambda x: x[3])
            print(f"\n Best kernel: {best[0]} ({best[3]:.2f} TFLOPS)")
    if failed == 0:
        print("\n *** ALL KERNELS PASSED ***")
    else:
        print(f"\n *** {failed} KERNELS FAILED ***")
    print("=" * 70)
    return 0 if failed == 0 else 1
if __name__ == "__main__":
    # Propagate main()'s status code to the shell.
    raise SystemExit(main())

View File

@@ -0,0 +1,149 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 02: Batch GEMM
Runs multiple GEMM operations with different sizes.
Complexity: ★★☆☆☆
Usage:
python3 02_batch_gemm.py
python3 02_batch_gemm.py --help
python3 02_batch_gemm.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
def main():
    """Run a batch of square GEMMs of increasing size on one dispatcher.

    Sets up a single 128x128x32 kernel configuration, runs every size up to
    --max-size, reports per-size and aggregate throughput, and returns 0
    (setup failure returns 1; individual GEMM errors are only reported).
    """
    parser = argparse.ArgumentParser(
        description="Batch GEMM Example - runs multiple sizes",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 02_batch_gemm.py                  # Default FP16
  python3 02_batch_gemm.py --dtype bf16     # BF16 GEMM
  python3 02_batch_gemm.py --max-size 2048  # Limit max size
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--max-size",
        type=int,
        default=4096,
        help="Maximum problem size (default: 4096)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()
    # Reset shared dispatcher state so this example starts clean.
    reset_for_example()
    print("=" * 60)
    print("Example 02: Batch GEMM")
    print("=" * 60)
    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")
    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )
    setup = setup_gemm_dispatcher(config, registry_name="batch_gemm", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    dispatcher = setup.dispatcher
    # =========================================================================
    # Step 2: Run batch of different sizes
    # =========================================================================
    print("\nStep 2: Run Batch")
    # Generate sizes up to max_size
    all_sizes = [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
    ]
    sizes = [(m, n, k) for m, n, k in all_sizes if max(m, n, k) <= args.max_size]
    # NOTE(review): bf16 host data is created as np.float16 — confirm the
    # dispatcher handles the bf16 conversion internally.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    print(f"\n {'Size':<20} | {'Time (ms)':>12} | {'TFLOPS':>10} | {'Status':>8}")
    print(" " + "-" * 60)
    # Aggregate FLOP count and wall time across successful runs.
    total_ops = 0
    total_time = 0
    for M, N, K in sizes:
        if not dispatcher.is_supported(M, N, K):
            print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Skipped")
            continue
        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1
        result = dispatcher.run(A, B, M, N, K)
        if result.success:
            # 2*M*N*K flops per GEMM (multiply + add).
            total_ops += 2 * M * N * K
            total_time += result.time_ms
            print(
                f" {M:>4}x{N:>4}x{K:<4} | {result.time_ms:>12.4f} | {result.tflops:>10.2f} | OK"
            )
        else:
            print(f" {M:>4}x{N:>4}x{K:<4} | {'N/A':>12} | {'N/A':>10} | Error")
    print(" " + "-" * 60)
    if total_time > 0:
        # Average throughput over the whole batch: TFLOP / seconds.
        avg_tflops = (total_ops / 1e12) / (total_time / 1000)
        print(f"\n Total: {total_time:.2f} ms, Average: {avg_tflops:.2f} TFLOPS")
    # Cleanup
    cleanup_gemm()
    print("\n" + "=" * 60)
    print("Batch GEMM complete!")
    print("=" * 60)
    return 0
if __name__ == "__main__":
    # Propagate main()'s status code to the shell.
    raise SystemExit(main())

View File

@@ -0,0 +1,171 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 03: Benchmark
Performance benchmarking with compute-optimized kernel configuration.
Complexity: ★★★☆☆
Usage:
python3 03_benchmark.py
python3 03_benchmark.py --help
python3 03_benchmark.py --size 4096
python3 03_benchmark.py --dtype bf16 --iterations 20
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
def main():
    """Benchmark a compute-optimized GEMM kernel over a set of sizes.

    Performs --warmup untimed runs per size, then --iterations timed runs,
    and reports min/avg time plus TFLOPS (derived from the average time).
    Returns 0 on completion, 1 if dispatcher setup fails.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Benchmark Example - performance testing",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 03_benchmark.py                   # Default benchmark suite
  python3 03_benchmark.py --size 4096       # Single size benchmark
  python3 03_benchmark.py --dtype bf16      # BF16 benchmark
  python3 03_benchmark.py --iterations 20   # More iterations
""",
    )
    parser.add_argument(
        "--dtype",
        default="bf16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: bf16)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=0,
        help="Single problem size MxNxK (default: run all sizes)",
    )
    parser.add_argument(
        "--warmup", type=int, default=3, help="Warmup iterations (default: 3)"
    )
    parser.add_argument(
        "--iterations", type=int, default=10, help="Benchmark iterations (default: 10)"
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()
    # Reset shared dispatcher state so this example starts clean.
    reset_for_example()
    print("=" * 60)
    print("Example 03: Benchmark")
    print("=" * 60)
    # =========================================================================
    # Step 1: Setup dispatcher with compute-optimized config
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")
    # compv4 pipeline + intrawave scheduler: the compute-optimized variant.
    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        pipeline="compv4",
        scheduler="intrawave",
        gfx_arch=args.arch,
    )
    setup = setup_gemm_dispatcher(config, registry_name="benchmark", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    dispatcher = setup.dispatcher
    # =========================================================================
    # Step 2: Benchmark
    # =========================================================================
    print("\nStep 2: Benchmark")
    # --size > 0 selects a single square problem; otherwise run the suite.
    if args.size > 0:
        sizes = [(args.size, args.size, args.size)]
    else:
        sizes = [
            (512, 512, 512),
            (1024, 1024, 1024),
            (2048, 2048, 2048),
            (4096, 4096, 4096),
            (1024, 2048, 512),
            (2048, 1024, 2048),
        ]
    # NOTE(review): bf16 host data is created as np.float16 — confirm the
    # dispatcher handles the bf16 conversion internally.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    print(f" Warmup: {args.warmup}, Iterations: {args.iterations}\n")
    print(f" {'Size':<20} | {'Min (ms)':>10} | {'Avg (ms)':>10} | {'TFLOPS':>10}")
    print(" " + "-" * 60)
    all_tflops = []
    for M, N, K in sizes:
        # Unsupported sizes are silently omitted from the table.
        if not dispatcher.is_supported(M, N, K):
            continue
        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1
        # Warmup (results discarded)
        for _ in range(args.warmup):
            dispatcher.run(A, B, M, N, K)
        # Benchmark (only successful runs contribute to the statistics)
        times = []
        for _ in range(args.iterations):
            result = dispatcher.run(A, B, M, N, K)
            if result.success:
                times.append(result.time_ms)
        if times:
            min_time = min(times)
            avg_time = sum(times) / len(times)
            # TFLOPS from the average time: 2*M*N*K flops / seconds / 1e12.
            tflops = (2.0 * M * N * K / (avg_time * 1e-3)) / 1e12
            all_tflops.append(tflops)
            print(
                f" {M:>4}x{N:>4}x{K:<4} | {min_time:>10.4f} | {avg_time:>10.4f} | {tflops:>10.2f}"
            )
    # Cleanup
    cleanup_gemm()
    # Summary
    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    if all_tflops:
        print(f" Average: {sum(all_tflops) / len(all_tflops):.2f} TFLOPS")
        print(f" Peak: {max(all_tflops):.2f} TFLOPS")
    print("=" * 60)
    return 0
if __name__ == "__main__":
    # Propagate main()'s status code to the shell.
    raise SystemExit(main())

View File

@@ -0,0 +1,156 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 04: Validation
Validates GPU GEMM against NumPy reference.
Complexity: ★★★☆☆
Usage:
python3 04_validation.py
python3 04_validation.py --help
python3 04_validation.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
Validator,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
def main():
    """Validate GPU GEMM output against an fp32 NumPy reference.

    Runs a fixed set of test cases (identity and random patterns, square and
    non-square) through one dispatcher and checks each result with the
    Validator at --rtol/--atol. Returns 0 if every executed case passed.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Validation Example - validates GPU results against NumPy",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 04_validation.py                # Default FP16 validation
  python3 04_validation.py --dtype bf16   # BF16 validation
  python3 04_validation.py --rtol 1e-2    # Relaxed tolerance
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--rtol", type=float, default=1e-3, help="Relative tolerance (default: 1e-3)"
    )
    parser.add_argument(
        "--atol", type=float, default=1e-2, help="Absolute tolerance (default: 1e-2)"
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()
    # Reset shared dispatcher state so this example starts clean.
    reset_for_example()
    print("=" * 60)
    print("Example 04: Validation")
    print("=" * 60)
    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")
    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )
    setup = setup_gemm_dispatcher(config, registry_name="validation", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    dispatcher = setup.dispatcher
    # =========================================================================
    # Step 2: Run validation tests
    # =========================================================================
    print("\nStep 2: Validation Tests")
    validator = Validator(rtol=args.rtol, atol=args.atol)
    # NOTE(review): bf16 host data is created as np.float16 — confirm the
    # dispatcher handles the bf16 conversion internally.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    # (label, M, N, K, input pattern)
    test_cases = [
        ("Identity", 128, 128, 128, "identity"),
        ("Small", 256, 256, 256, "random"),
        ("Medium", 512, 512, 512, "random"),
        ("Large", 1024, 1024, 1024, "random"),
        ("Non-square", 512, 1024, 256, "random"),
    ]
    passed = 0
    failed = 0
    print(f"\n {'Test':<15} | {'Size':<15} | {'Max Err':>10} | {'Status':>8}")
    print(" " + "-" * 55)
    for name, M, N, K, pattern in test_cases:
        # Skipped (unsupported) cases count as neither passed nor failed.
        if not dispatcher.is_supported(M, N, K):
            print(f" {name:<15} | {M}x{N}x{K:<5} | {'N/A':>10} | Skipped")
            continue
        # Fixed seed so random cases are reproducible run-to-run.
        np.random.seed(42)
        if pattern == "identity":
            # Identity inputs: product is the (rectangular) identity.
            A = np.eye(M, K, dtype=np_dtype)
            B = np.eye(K, N, dtype=np_dtype)
        else:
            A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
            B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {'GPU Err':>10} | FAILED")
            failed += 1
            continue
        # Reference computed in fp32, then cast back to the test dtype.
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        is_valid, max_err, _ = validator.check(result.output, C_ref)
        if is_valid:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | PASSED")
            passed += 1
        else:
            print(f" {name:<15} | {M}x{N}x{K:<5} | {max_err:>10.2e} | FAILED")
            failed += 1
    # Cleanup
    cleanup_gemm()
    # Summary
    print("\n" + "=" * 60)
    total = passed + failed
    print(f"Results: {passed}/{total} passed")
    print(f"Settings: dtype={args.dtype}, rtol={args.rtol}, atol={args.atol}")
    print("=" * 60)
    return 0 if failed == 0 else 1
if __name__ == "__main__":
    # Propagate main()'s status code to the shell.
    raise SystemExit(main())

View File

@@ -0,0 +1,166 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 05: NumPy Integration
Shows how to create a GPU-accelerated matmul wrapper.
Complexity: ★★☆☆☆
Usage:
python3 05_numpy_integration.py
python3 05_numpy_integration.py --help
python3 05_numpy_integration.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
Dispatcher,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
class GPUMatmul:
    """Callable wrapper that routes matrix multiplication to a GPU dispatcher.

    Falls back to NumPy whenever the dispatcher does not support the
    requested problem size or the GPU run does not succeed.
    """

    def __init__(self, dispatcher: Dispatcher):
        # Dispatcher consulted (and used) on every call.
        self.dispatcher = dispatcher

    def __call__(self, A: np.ndarray, B: np.ndarray) -> np.ndarray:
        """Compute C = A @ B on GPU with CPU fallback."""
        rows, inner = A.shape
        inner_b, cols = B.shape
        if inner != inner_b:
            raise ValueError(f"Dimension mismatch: {A.shape} @ {B.shape}")
        if self.dispatcher.is_supported(rows, cols, inner):
            outcome = self.dispatcher.run(A, B, rows, cols, inner)
            if outcome.success:
                return outcome.output
        return np.matmul(A, B)
def main():
    """Run the NumPy-integration demo.

    Builds a GEMM dispatcher, wraps it in GPUMatmul, then exercises it on a
    plain matrix product and a two-layer FFN-style block.

    Returns:
        0 on success, 1 if dispatcher setup fails.
    """
    parser = argparse.ArgumentParser(
        description="NumPy Integration Example - GPU-accelerated matmul wrapper",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 05_numpy_integration.py            # Default FP16
    python3 05_numpy_integration.py --dtype bf16  # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()
    reset_for_example()
    print("=" * 60)
    print("Example 05: NumPy Integration")
    print("=" * 60)
    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")
    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )
    setup = setup_gemm_dispatcher(config, registry_name="numpy", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    dispatcher = setup.dispatcher
    # NOTE(review): bf16 host buffers are staged as np.float16 — confirm the
    # conversion is handled inside ctypes_utils.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    # =========================================================================
    # Step 2: Create GPU matmul wrapper
    # =========================================================================
    print("\nStep 2: Create GPUMatmul")
    gpu_matmul = GPUMatmul(dispatcher=dispatcher)
    print(" gpu_matmul ready")
    # =========================================================================
    # Step 3: Demo - Simple multiplication using gpu_matmul
    # =========================================================================
    print("\nStep 3: Demo - Simple Multiplication")
    A = np.random.randn(1024, 512).astype(np_dtype) * 0.1
    B = np.random.randn(512, 256).astype(np_dtype) * 0.1
    # Use the gpu_matmul wrapper
    C = gpu_matmul(A, B)
    print(f" gpu_matmul result: {C.shape}, sum={C.sum():.4f}")
    # Direct dispatcher call (bypassing the wrapper) to surface timing stats.
    M, K = A.shape
    _, N = B.shape
    result = dispatcher.run(A, B, M, N, K)
    print(f" A: {A.shape}, B: {B.shape} -> C: {result.output.shape}")
    print(f" GPU: {result.time_ms:.4f} ms, {result.tflops:.2f} TFLOPS")
    # =========================================================================
    # Step 4: Demo - FFN block
    # =========================================================================
    print("\nStep 4: Demo - FFN Block")
    batch, hidden, ffn = 128, 768, 3072
    X = np.random.randn(batch, hidden).astype(np_dtype) * 0.02
    W1 = np.random.randn(hidden, ffn).astype(np_dtype) * 0.02
    W2 = np.random.randn(ffn, hidden).astype(np_dtype) * 0.02
    # Two chained GEMMs: X @ W1 -> H, then H @ W2 -> Y.
    result1 = dispatcher.run(X, W1, batch, ffn, hidden)
    H = result1.output
    result2 = dispatcher.run(H, W2, batch, hidden, ffn)
    print(f" X: {X.shape} -> H: {H.shape} -> Y: {result2.output.shape}")
    print(f" Total: {result1.time_ms + result2.time_ms:.4f} ms")
    # Cleanup
    cleanup_gemm()
    # Summary
    print("\n" + "=" * 60)
    print("NumPy Integration Pattern:")
    print("=" * 60)
    print(" 1. setup_gemm_dispatcher(config)")
    print(" 2. GPUMatmul(dispatcher)")
    print(" 3. C = gpu_matmul(A, B)")
    print("=" * 60)
    return 0
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,169 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 06: JSON Export
Exports registry configuration to JSON.
Complexity: ★★☆☆☆
Usage:
python3 06_json_export.py
python3 06_json_export.py --help
python3 06_json_export.py --output my_kernels.json
"""
import sys
import json
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
def main():
    """Export the kernel registry configuration to a JSON file.

    Builds a dispatcher, assembles a list of kernel configs, serializes them
    (plus any C++-side registry info) to ``--output``, and prints a preview.

    Returns:
        0 on success, 1 if dispatcher setup fails.
    """
    parser = argparse.ArgumentParser(
        description="JSON Export Example - exports registry to JSON",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 06_json_export.py                  # Default output to kernels.json
    python3 06_json_export.py --output my.json # Custom output file
""",
    )
    parser.add_argument(
        "--output",
        "-o",
        default="kernels.json",
        help="Output JSON file (default: kernels.json)",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()
    reset_for_example()
    print("=" * 60)
    print("Example 06: JSON Export")
    print("=" * 60)
    # =========================================================================
    # Step 1: Setup dispatcher
    # =========================================================================
    print("\nStep 1: Setup Dispatcher")
    config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )
    setup = setup_gemm_dispatcher(config, registry_name="export_demo", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    # =========================================================================
    # Step 2: Define additional configs for export
    # =========================================================================
    print("\nStep 2: Define Additional Configs")
    configs = [
        config,
        KernelConfig(
            dtype_a=args.dtype,
            dtype_b=args.dtype,
            dtype_c=args.dtype,
            tile_m=256,
            tile_n=256,
            tile_k=64,
            gfx_arch=args.arch,
        ),
        KernelConfig(
            dtype_a=args.dtype,
            dtype_b=args.dtype,
            dtype_c=args.dtype,
            tile_m=64,
            tile_n=64,
            tile_k=32,
            gfx_arch=args.arch,
        ),
    ]
    for cfg in configs:
        print(f" - {cfg.tile_str}")
    # =========================================================================
    # Step 3: Export to JSON
    # =========================================================================
    print("\nStep 3: Export to JSON")
    export_data = {
        "registry": setup.registry.name,
        "kernel_count": len(configs),
        "kernels": [],
    }
    for cfg in configs:
        kernel_info = {
            "tile": cfg.tile_str,
            "dtypes": {"A": cfg.dtype_a, "B": cfg.dtype_b, "C": cfg.dtype_c},
            "layout": cfg.layout,
            "pipeline": cfg.pipeline,
            "target": cfg.gfx_arch,
        }
        export_data["kernels"].append(kernel_info)
    # Include C++ library info
    if setup.lib:
        cpp_json = setup.lib.export_registry_json()
        try:
            export_data["cpp_registry"] = json.loads(cpp_json)
        except json.JSONDecodeError:
            # Best effort: omit the C++ registry section if its JSON is malformed.
            pass
    json_str = json.dumps(export_data, indent=2)
    with open(args.output, "w") as f:
        f.write(json_str)
    print(f" Saved to: {args.output}")
    # Preview (truncated to the first 500 characters)
    print("\nStep 4: Preview")
    print("-" * 60)
    print(json_str[:500] + ("..." if len(json_str) > 500 else ""))
    print("-" * 60)
    # Cleanup
    cleanup_gemm()
    print("\n" + "=" * 60)
    print("JSON Export complete!")
    print("=" * 60)
    return 0
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,513 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 07: Stress Test - Multiple Kernels with Validation
Consolidated stress test that:
1. Declares multiple kernel configurations (various tiles, pipelines, layouts)
2. Prints all registered kernels with details
3. Validates each kernel against NumPy reference
4. Optional benchmarking mode
This tests:
- Multiple tile sizes (64x64, 128x128, 256x256)
- Multiple pipelines (compv3, compv4)
- Multiple data types (fp16, bf16)
- Different schedulers (intrawave, interwave)
Complexity: ★★★★☆
Usage:
python3 07_stress_test.py
python3 07_stress_test.py --help
python3 07_stress_test.py --num-kernels 10
python3 07_stress_test.py --benchmark
python3 07_stress_test.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List, Tuple
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
Validator,
)
@dataclass
class KernelSpec:
    """A kernel specification for testing"""

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    wave_m: int = 2
    wave_n: int = 2
    wave_k: int = 1
    warp_m: int = 32
    warp_n: int = 32
    warp_k: int = 16
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    layout: str = "rcr"

    def to_config(self, dtype: str, arch: str) -> KernelConfig:
        """Build the full KernelConfig described by this spec."""
        axis = {"r": "row", "c": "col"}
        # Clamp warp tiles so a wave's warp tile never exceeds its share of
        # the block tile (relevant for the smaller tile sizes).
        clamped_warp_m = min(self.warp_m, self.tile_m // self.wave_m)
        clamped_warp_n = min(self.warp_n, self.tile_n // self.wave_n)
        return KernelConfig(
            dtype_a=dtype,
            dtype_b=dtype,
            dtype_c=dtype,
            dtype_acc="fp32",
            layout_a=axis[self.layout[0]],
            layout_b=axis[self.layout[1]],
            layout_c=axis[self.layout[2]],
            tile_m=self.tile_m,
            tile_n=self.tile_n,
            tile_k=self.tile_k,
            wave_m=self.wave_m,
            wave_n=self.wave_n,
            wave_k=self.wave_k,
            warp_m=clamped_warp_m,
            warp_n=clamped_warp_n,
            warp_k=self.warp_k,
            pipeline=self.pipeline,
            scheduler=self.scheduler,
            epilogue="cshuffle",
            gfx_arch=arch,
        )
# Define stress test kernel configurations.
# Each entry is KernelSpec(name, tile_m, tile_n, tile_k, ...); the set covers
# tile shape, pipeline (compv3/compv4) and scheduler (intra/interwave) axes.
KERNEL_SPECS = [
    # Small tiles - compv3 / compv4
    KernelSpec("small_compv3", 64, 64, 32, wave_m=2, wave_n=2, warp_m=16, warp_n=16, warp_k=32, pipeline="compv3"),
    KernelSpec("small_compv4", 64, 64, 32, wave_m=2, wave_n=2, warp_m=16, warp_n=16, warp_k=32, pipeline="compv4"),
    # Medium tiles
    KernelSpec("medium_compv3", 128, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("medium_compv4", 128, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv4"),
    KernelSpec("medium_k64", 128, 128, 64, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    # Rectangular tiles
    KernelSpec("rect_64x128", 64, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("rect_128x64", 128, 64, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    # Different schedulers
    KernelSpec("interwave", 128, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3", scheduler="interwave"),
    # Large tiles
    KernelSpec("large_compv3", 256, 128, 32, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv3"),
    KernelSpec("large_compv4", 256, 128, 64, wave_m=2, wave_n=2, warp_m=32, warp_n=32, warp_k=16, pipeline="compv4"),
]
def print_kernel_summary(specs: List[KernelSpec], dtype: str):
    """Print a numbered summary table of the given kernel specs."""
    banner = "=" * 80
    rule = " " + "-" * 78
    print("\n" + banner)
    print(f" DECLARED KERNEL CONFIGURATIONS ({len(specs)} kernels)")
    print(banner)
    print(
        f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Wave':<10} {'Warp':<12} {'Pipeline':<10} {'Sched':<10}"
    )
    print(rule)
    for row, sp in enumerate(specs, 1):
        tile = f"{sp.tile_m}x{sp.tile_n}x{sp.tile_k}"
        wave = f"{sp.wave_m}x{sp.wave_n}x{sp.wave_k}"
        warp = f"{sp.warp_m}x{sp.warp_n}x{sp.warp_k}"
        print(
            f" {row:<3} {sp.name:<18} {tile:<12} {wave:<10} {warp:<12} {sp.pipeline:<10} {sp.scheduler:<10}"
        )
    print(rule)
    print(f" Data type: {dtype}\n")
def validate_kernel(
    spec: KernelSpec,
    dtype: str,
    arch: str,
    size: int,
    validator: Validator,
    kernel_index: int = 0,
    verbose: bool = False,
) -> Tuple[bool, float, str]:
    """
    Validate a single kernel configuration against a NumPy reference.

    Builds a dispatcher for *spec*, runs one size x size x size GEMM on the
    GPU and compares the output with ``np.matmul`` accumulated in fp32.

    Args:
        spec: Kernel configuration to validate.
        dtype: Element type name ("fp16", "bf16" or "fp32").
        arch: Target GPU architecture string (e.g. "gfx942").
        size: Problem size; M = N = K = size.
        validator: Tolerance checker comparing GPU output vs reference.
        kernel_index: Per-kernel index used to derive a unique RNG seed.
        verbose: When True, forward verbose mode to the dispatcher setup and
            print progress. (Fix: this flag was previously accepted but
            silently ignored.)

    Returns: (passed, max_error, message)
    """
    # bf16 host buffers are staged as np.float16 — TODO confirm conversion
    # happens inside ctypes_utils.
    np_dtype = np.float16 if dtype in ["fp16", "bf16"] else np.float32
    # Create config
    config = spec.to_config(dtype, arch)
    # Setup dispatcher; forward the verbose flag instead of dropping it.
    setup = setup_gemm_dispatcher(
        config=config,
        registry_name=f"stress_{spec.name}",
        verbose=verbose,
        auto_rebuild=True,
    )
    if not setup.success:
        return False, 0.0, f"Setup failed: {setup.error}"
    dispatcher = setup.dispatcher
    M, N, K = size, size, size
    if not dispatcher.is_supported(M, N, K):
        cleanup_gemm()
        return False, 0.0, f"Size {M}x{N}x{K} not supported"
    # Use different seed per kernel to get unique test data
    # This ensures each kernel is tested with different matrices
    np.random.seed(42 + kernel_index * 1000)
    A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
    B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
    if verbose:
        print(f" [{spec.name}] running {M}x{N}x{K} GEMM")
    # Run GPU GEMM
    result = dispatcher.run(A, B, M, N, K)
    if not result.success:
        cleanup_gemm()
        return False, 0.0, "GPU execution failed"
    # Validate against NumPy (accumulate in fp32, then cast back)
    C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
    is_valid, max_err, _ = validator.check(result.output, C_ref)
    cleanup_gemm()
    return is_valid, max_err, f"{result.time_ms:.2f}ms, {result.tflops:.1f} TFLOPS"
def benchmark_kernel(
    spec: KernelSpec,
    dtype: str,
    arch: str,
    size: int,
    warmup: int = 3,
    iterations: int = 10,
) -> Tuple[bool, float, float]:
    """
    Benchmark a kernel configuration.
    Returns: (success, avg_time_ms, tflops)
    """
    elem_type = np.float16 if dtype in ["fp16", "bf16"] else np.float32
    setup = setup_gemm_dispatcher(
        config=spec.to_config(dtype, arch),
        registry_name=f"bench_{spec.name}",
        verbose=False,
        auto_rebuild=True,
    )
    if not setup.success:
        return False, 0.0, 0.0
    dispatcher = setup.dispatcher
    M = N = K = size
    if not dispatcher.is_supported(M, N, K):
        cleanup_gemm()
        return False, 0.0, 0.0
    A = (np.random.randn(M, K) * 0.1).astype(elem_type)
    B = (np.random.randn(K, N) * 0.1).astype(elem_type)
    # Discard warmup runs so compilation/caching doesn't skew the timing.
    for _ in range(warmup):
        dispatcher.run(A, B, M, N, K)
    # Collect timings only from runs that actually succeeded.
    samples = [
        r.time_ms
        for r in (dispatcher.run(A, B, M, N, K) for _ in range(iterations))
        if r.success
    ]
    cleanup_gemm()
    if not samples:
        return False, 0.0, 0.0
    mean_ms = sum(samples) / len(samples)
    # 2*M*N*K flops per GEMM, converted to TFLOP/s from milliseconds.
    achieved = (2.0 * M * N * K / (mean_ms * 1e-3)) / 1e12
    return True, mean_ms, achieved
def main():
    """Drive the stress test.

    Prints the selected kernel pool, validates each spec against a NumPy
    reference (optionally benchmarking it), and prints a pass/fail summary.

    Returns:
        0 when no kernel failed validation, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="GEMM Stress Test - Multiple kernels with validation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 07_stress_test.py                  # Test all kernels
    python3 07_stress_test.py --num-kernels 5  # Test first 5 kernels
    python3 07_stress_test.py --benchmark      # Include benchmarks
    python3 07_stress_test.py --dtype bf16     # Test BF16
    python3 07_stress_test.py --size 2048      # Use 2048x2048 matrices
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--num-kernels",
        type=int,
        default=0,
        help="Number of kernels to test (0 = all)",
    )
    parser.add_argument(
        "--size",
        type=int,
        default=512,
        help="Problem size MxNxK (default: 512)",
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="Include benchmark timing",
    )
    parser.add_argument(
        "--rtol",
        type=float,
        default=1e-2,
        help="Relative tolerance (default: 1e-2)",
    )
    parser.add_argument(
        "--atol",
        type=float,
        default=1e-2,
        help="Absolute tolerance (default: 1e-2)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    args = parser.parse_args()
    reset_for_example()
    print("=" * 80)
    print("Example 07: GEMM Stress Test - Multiple Kernels")
    print("=" * 80)
    # Select kernels to test
    specs = KERNEL_SPECS[: args.num_kernels] if args.num_kernels > 0 else KERNEL_SPECS
    # Print kernel summary
    print_kernel_summary(specs, args.dtype)
    # Run validation
    print("\n" + "=" * 80)
    print(" VALIDATION RESULTS")
    print("=" * 80)
    validator = Validator(rtol=args.rtol, atol=args.atol)
    if args.benchmark:
        print(
            f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Time':>10} {'TFLOPS':>8} {'Status':<8}"
        )
    else:
        print(
            f"\n {'#':<3} {'Name':<18} {'Tile':<12} {'Max Err':>10} {'Info':<25} {'Status':<8}"
        )
    print(" " + "-" * 78)
    passed = 0
    failed = 0
    skipped = 0
    for i, spec in enumerate(specs, 1):
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        try:
            # NOTE(review): validate_kernel reports unsupported sizes as
            # failures (False), not skips — only raised exceptions are SKIP.
            is_valid, max_err, info = validate_kernel(
                spec, args.dtype, args.arch, args.size, validator, kernel_index=i
            )
            if is_valid:
                status = "PASS"
                passed += 1
            else:
                status = "FAIL"
                failed += 1
            if args.benchmark:
                success, avg_time, tflops = benchmark_kernel(
                    spec, args.dtype, args.arch, args.size
                )
                if success:
                    print(
                        f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {avg_time:>9.2f}ms {tflops:>7.1f} {status:<8}"
                    )
                else:
                    print(
                        f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {'N/A':>10} {'N/A':>8} {status:<8}"
                    )
            else:
                print(
                    f" {i:<3} {spec.name:<18} {tile:<12} {max_err:>10.2e} {info:<25} {status:<8}"
                )
        except Exception as e:
            # Any unexpected error (e.g. an unbuildable spec) marks the
            # kernel as skipped rather than aborting the whole run.
            skipped += 1
            print(
                f" {i:<3} {spec.name:<18} {tile:<12} {'N/A':>10} {str(e)[:25]:<25} {'SKIP':<8}"
            )
    # Summary
    print("\n" + "=" * 80)
    print(" SUMMARY")
    print("=" * 80)
    total = passed + failed + skipped
    print(f"\n Results: {passed}/{total} passed, {failed} failed, {skipped} skipped")
    print(f" Settings: dtype={args.dtype}, size={args.size}x{args.size}x{args.size}")
    print(f" Tolerance: rtol={args.rtol}, atol={args.atol}")
    print(f" Architecture: {args.arch}")
    if failed == 0 and skipped == 0:
        print("\n *** ALL KERNELS PASSED ***")
    elif failed > 0:
        print(f"\n *** {failed} KERNELS FAILED ***")
    print("=" * 80)
    return 0 if failed == 0 else 1
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,718 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 08: Custom Heuristics
Demonstrates custom kernel selection heuristics based on problem characteristics.
This example shows how to:
1. Define multiple kernel configurations for different workloads
2. Implement custom heuristics to select the best kernel
3. Test heuristic selection across different problem sizes
Heuristic strategies:
- Size-based: Small tiles for small problems, large tiles for large problems
- Compute-bound: Maximize compute utilization for large matrices
- Memory-bound: Optimize memory access for bandwidth-limited cases
- Latency-focused: Minimize kernel launch overhead for small problems
Complexity: ★★★★☆
Usage:
python3 08_heuristics.py
python3 08_heuristics.py --help
python3 08_heuristics.py --strategy compute
python3 08_heuristics.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
from dataclasses import dataclass
from typing import List
from enum import Enum
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
# =============================================================================
# Kernel Specifications
# =============================================================================
@dataclass
class KernelSpec:
    """Kernel specification with metadata for heuristic selection.

    The ``category`` and problem-size bounds are not used by the kernel
    itself; they are hints consumed by the heuristic selection functions.
    """

    name: str
    tile_m: int
    tile_n: int
    tile_k: int
    pipeline: str = "compv3"
    scheduler: str = "intrawave"
    # Metadata for heuristics
    category: str = "balanced"  # small, balanced, large, compute, memory
    min_problem_size: int = 0
    # Fix: annotated as float — the default is float("inf"), which is not an int.
    max_problem_size: float = float("inf")
# Define kernel pool for heuristic selection (20+ kernels), grouped by intent.
# Each entry: KernelSpec(name, tile_m, tile_n, tile_k, pipeline, scheduler,
# category=..., min/max_problem_size=...). The size bounds are in output
# elements (M * N).
KERNEL_POOL = [
    # -- SMALL TILES: low latency, good for small problems --------------------
    KernelSpec("small_64x64_k32", 64, 64, 32, "compv3", "intrawave", category="small", max_problem_size=256 * 256),
    KernelSpec("small_64x64_k64", 64, 64, 64, "compv3", "intrawave", category="small", max_problem_size=256 * 256),
    KernelSpec("small_64x64_v4", 64, 64, 32, "compv4", "intrawave", category="small", max_problem_size=256 * 256),
    # -- MEDIUM TILES: balanced performance -----------------------------------
    KernelSpec("medium_128x128_k32", 128, 128, 32, "compv3", "intrawave", category="balanced", min_problem_size=128 * 128, max_problem_size=2048 * 2048),
    KernelSpec("medium_128x128_k64", 128, 128, 64, "compv3", "intrawave", category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_k128", 128, 128, 128, "compv3", "intrawave", category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_v4_k32", 128, 128, 32, "compv4", "intrawave", category="balanced", min_problem_size=256 * 256),
    KernelSpec("medium_128x128_v4_k64", 128, 128, 64, "compv4", "intrawave", category="balanced", min_problem_size=256 * 256),
    # Rectangular medium tiles
    KernelSpec("rect_64x128_k32", 64, 128, 32, "compv3", "intrawave", category="balanced", min_problem_size=128 * 128),
    KernelSpec("rect_128x64_k32", 128, 64, 32, "compv3", "intrawave", category="balanced", min_problem_size=128 * 128),
    KernelSpec("rect_64x128_k64", 64, 128, 64, "compv3", "intrawave", category="balanced", min_problem_size=256 * 256),
    KernelSpec("rect_128x64_k64", 128, 64, 64, "compv3", "intrawave", category="balanced", min_problem_size=256 * 256),
    # -- LARGE TILES: high throughput for large problems ----------------------
    KernelSpec("large_256x128_k32", 256, 128, 32, "compv3", "intrawave", category="large", min_problem_size=512 * 512),
    KernelSpec("large_256x128_k64", 256, 128, 64, "compv3", "intrawave", category="large", min_problem_size=512 * 512),
    KernelSpec("large_128x256_k32", 128, 256, 32, "compv3", "intrawave", category="large", min_problem_size=512 * 512),
    KernelSpec("large_128x256_k64", 128, 256, 64, "compv3", "intrawave", category="large", min_problem_size=512 * 512),
    KernelSpec("large_256x256_k32", 256, 256, 32, "compv3", "intrawave", category="large", min_problem_size=1024 * 1024),
    KernelSpec("large_256x256_k64", 256, 256, 64, "compv3", "intrawave", category="large", min_problem_size=1024 * 1024),
    # -- COMPUTE-OPTIMIZED: compv4 pipeline for compute-bound workloads -------
    KernelSpec("compute_128x128_v4_k32", 128, 128, 32, "compv4", "intrawave", category="compute", min_problem_size=256 * 256),
    KernelSpec("compute_128x128_v4_k64", 128, 128, 64, "compv4", "intrawave", category="compute", min_problem_size=256 * 256),
    KernelSpec("compute_256x128_v4", 256, 128, 64, "compv4", "intrawave", category="compute", min_problem_size=512 * 512),
    KernelSpec("compute_256x256_v4", 256, 256, 64, "compv4", "intrawave", category="compute", min_problem_size=1024 * 1024),
    # -- MEMORY-OPTIMIZED: small tile_k for memory-bound workloads ------------
    KernelSpec("memory_128x128_k16", 128, 128, 16, "compv3", "intrawave", category="memory", min_problem_size=256 * 256),
    KernelSpec("memory_64x128_k16", 64, 128, 16, "compv3", "intrawave", category="memory", min_problem_size=128 * 128),
]
def create_kernel_config(spec: KernelSpec, dtype: str, arch: str) -> KernelConfig:
    """Build the KernelConfig for a pool spec (fixed rcr layout, fp32 acc)."""
    # Use narrower warp tiles when the block tile itself is small.
    narrow_m = spec.tile_m <= 64
    narrow_n = spec.tile_n <= 64
    return KernelConfig(
        dtype_a=dtype,
        dtype_b=dtype,
        dtype_c=dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",
        layout_c="row",
        tile_m=spec.tile_m,
        tile_n=spec.tile_n,
        tile_k=spec.tile_k,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=16 if narrow_m else 32,
        warp_n=16 if narrow_n else 32,
        warp_k=16,
        pipeline=spec.pipeline,
        scheduler=spec.scheduler,
        epilogue="cshuffle",
        gfx_arch=arch,
    )
# =============================================================================
# Heuristic Strategies
# =============================================================================
class HeuristicStrategy(Enum):
    """Kernel-selection strategies; values match the --strategy CLI choices."""

    SIZE_BASED = "size"
    COMPUTE_BOUND = "compute"
    MEMORY_BOUND = "memory"
    LATENCY_FOCUSED = "latency"
def size_based_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select a kernel from the problem's output size.

    Small problems get small tiles (low latency), medium problems balanced
    tiles, large problems large tiles (throughput). Among the remaining
    candidates, prefer a tile_k that divides K evenly, then the larger tile.
    """
    out_elems = M * N
    # Keep kernels whose declared problem-size window contains this problem.
    pool = [
        k
        for k in kernels
        if k.min_problem_size <= out_elems <= k.max_problem_size
    ]
    if not pool:
        pool = kernels  # Fall back to all kernels
    # Map output size to a target category.
    if out_elems < 256 * 256:
        wanted = "small"
    elif out_elems < 1024 * 1024:
        wanted = "balanced"
    else:
        wanted = "large"
    in_category = [k for k in pool if k.category == wanted]
    if in_category:
        pool = in_category

    def k_fit(k):
        # 0 when tile_k divides K exactly; otherwise the remainder
        # (smaller remainder is considered a better fit).
        return K % k.tile_k

    # Sort by tile_k fit first, breaking ties toward the larger tile area.
    pool.sort(key=lambda k: (k_fit(k), -k.tile_m * k.tile_n))
    return pool[0]
def compute_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select a kernel for compute-bound workloads.

    Prefers the "compute" category (falling back to any compv4 kernel, then
    the whole pool), restricts to kernels whose minimum problem size fits,
    then picks the biggest tile volume for large problems or the tile
    closest to 128x128 for smaller ones.
    """
    out_elems = M * N
    pool = [k for k in kernels if k.category == "compute"]
    if not pool:
        # Fall back to compv4 kernels, then to the full pool.
        pool = [k for k in kernels if k.pipeline == "compv4"] or kernels
    sized = [k for k in pool if k.min_problem_size <= out_elems]
    if sized:
        pool = sized
    if out_elems >= 1024 * 1024:
        # Large problems: maximize tile volume for utilization.
        return max(pool, key=lambda k: k.tile_m * k.tile_n * k.tile_k)
    # Smaller problems: tile closest to 128x128.
    return min(pool, key=lambda k: abs(k.tile_m - 128) + abs(k.tile_n - 128))
def memory_bound_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select a kernel for memory-bound workloads.

    Prefers the "memory" category (sized by output area), then balanced
    kernels with the smallest tile_k, then whatever has the smallest tile_k
    and a tile closest to 128x128.
    """
    mem_pool = [k for k in kernels if k.category == "memory"]
    if mem_pool:
        # Within the memory category, pick by output area: small problems
        # get the smallest tile, large problems the largest.
        if M * N < 512 * 512:
            return min(mem_pool, key=lambda k: k.tile_m * k.tile_n)
        return max(mem_pool, key=lambda k: k.tile_m * k.tile_n)
    balanced_pool = [k for k in kernels if k.category == "balanced"]
    if balanced_pool:
        # Prefer smaller tile_k for memory-bound
        return min(balanced_pool, key=lambda k: k.tile_k)
    return min(
        kernels, key=lambda k: (k.tile_k, abs(k.tile_m - 128) + abs(k.tile_n - 128))
    )
def latency_focused_heuristic(
    M: int, N: int, K: int, kernels: List[KernelSpec]
) -> KernelSpec:
    """
    Select a kernel for low latency.

    Preference order: first compv4 kernel in the "small" category, first
    "small" kernel of any pipeline, smallest compv4 tile, smallest tile.
    """
    small = [k for k in kernels if k.category == "small"]
    if small:
        for cand in small:
            if cand.pipeline == "compv4":
                # Among small kernels, compv4 is preferred for latency.
                return cand
        return small[0]
    v4_pool = [k for k in kernels if k.pipeline == "compv4"]
    if v4_pool:
        return min(v4_pool, key=lambda k: k.tile_m * k.tile_n)
    # Last resort: the smallest tile overall.
    return min(kernels, key=lambda k: k.tile_m * k.tile_n)
# Strategy-enum -> selection-function lookup; main() uses this to resolve the
# --strategy CLI choice into a callable.
HEURISTICS = {
    HeuristicStrategy.SIZE_BASED: size_based_heuristic,
    HeuristicStrategy.COMPUTE_BOUND: compute_bound_heuristic,
    HeuristicStrategy.MEMORY_BOUND: memory_bound_heuristic,
    HeuristicStrategy.LATENCY_FOCUSED: latency_focused_heuristic,
}
# =============================================================================
# Main
# =============================================================================
def print_kernel_pool(kernels: List[KernelSpec]):
    """Print a numbered table of the available kernel specs."""
    banner = "=" * 75
    rule = " " + "-" * 73
    print("\n" + banner)
    print(" KERNEL POOL")
    print(banner)
    print(f"\n {'#':<3} {'Name':<22} {'Tile':<14} {'Pipeline':<10} {'Category':<12}")
    print(rule)
    for row, spec in enumerate(kernels, 1):
        tile = f"{spec.tile_m}x{spec.tile_n}x{spec.tile_k}"
        print(f" {row:<3} {spec.name:<22} {tile:<14} {spec.pipeline:<10} {spec.category:<12}")
    print(rule)
def main():
    """Run the heuristics demo.

    Resolves the --strategy choice to a selection function, prints the
    kernel pool, then for each test size selects a kernel, runs the GEMM,
    and validates the result against NumPy.

    Returns:
        0 when all executed tests pass, 1 otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Custom Heuristics Example - intelligent kernel selection",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
    python3 08_heuristics.py                     # Default size-based heuristic
    python3 08_heuristics.py --strategy compute  # Compute-bound heuristic
    python3 08_heuristics.py --strategy memory   # Memory-bound heuristic
    python3 08_heuristics.py --strategy latency  # Latency-focused heuristic
    python3 08_heuristics.py --dtype bf16        # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--strategy",
        default="size",
        choices=["size", "compute", "memory", "latency"],
        help="Heuristic strategy (default: size)",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target architecture (default: gfx942)",
    )
    args = parser.parse_args()
    reset_for_example()
    print("=" * 75)
    print("Example 08: Custom Heuristics")
    print("=" * 75)
    # Map strategy string to enum
    strategy_map = {
        "size": HeuristicStrategy.SIZE_BASED,
        "compute": HeuristicStrategy.COMPUTE_BOUND,
        "memory": HeuristicStrategy.MEMORY_BOUND,
        "latency": HeuristicStrategy.LATENCY_FOCUSED,
    }
    strategy = strategy_map[args.strategy]
    heuristic_fn = HEURISTICS[strategy]
    print(f"\n Strategy: {strategy.value}")
    print(f" Data type: {args.dtype}")
    # Print kernel pool
    print_kernel_pool(KERNEL_POOL)
    # =========================================================================
    # Test heuristic selection across different problem sizes
    # =========================================================================
    print("\n" + "=" * 75)
    print(" HEURISTIC SELECTION TEST")
    print("=" * 75)
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    test_sizes = [
        (128, 128, 64), # Small
        (256, 256, 128), # Small-medium
        (512, 512, 256), # Medium
        (1024, 1024, 512), # Medium-large
        (2048, 2048, 1024), # Large
    ]
    print(
        f"\n {'Size':<20} {'Selected Kernel':<25} {'Time (ms)':>10} {'TFLOPS':>10} {'Status':<8}"
    )
    print(" " + "-" * 78)
    # Each result tuple: (size_str, kernel_name, passed, time_ms, tflops).
    results = []
    for M, N, K in test_sizes:
        # Use heuristic to select kernel
        selected_spec = heuristic_fn(M, N, K, KERNEL_POOL)
        # Create config and setup
        config = create_kernel_config(selected_spec, args.dtype, args.arch)
        setup = setup_gemm_dispatcher(
            config=config,
            registry_name=f"heuristic_{selected_spec.name}",
            verbose=False,
            auto_rebuild=True,
        )
        size_str = f"{M}x{N}x{K}"
        if not setup.success:
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue
        dispatcher = setup.dispatcher
        if not dispatcher.is_supported(M, N, K):
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'SKIP':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue
        # Run GEMM
        np.random.seed(42)
        A = (np.random.randn(M, K) * 0.1).astype(np_dtype)
        B = (np.random.randn(K, N) * 0.1).astype(np_dtype)
        result = dispatcher.run(A, B, M, N, K)
        if not result.success:
            print(
                f" {size_str:<20} {selected_spec.name:<25} {'N/A':>10} {'N/A':>10} {'FAIL':<8}"
            )
            results.append((size_str, selected_spec.name, False, 0, 0))
            cleanup_gemm()
            continue
        # Validate
        # NOTE(review): tolerance is hard-coded at 1e-2 here rather than
        # using the Validator class from the other examples — confirm.
        C_ref = np.matmul(A.astype(np.float32), B.astype(np.float32)).astype(np_dtype)
        max_err = np.max(np.abs(result.output - C_ref))
        passed = max_err < 1e-2
        status = "PASS" if passed else "FAIL"
        print(
            f" {size_str:<20} {selected_spec.name:<25} {result.time_ms:>10.4f} {result.tflops:>10.2f} {status:<8}"
        )
        results.append(
            (size_str, selected_spec.name, passed, result.time_ms, result.tflops)
        )
        cleanup_gemm()
    # =========================================================================
    # Summary
    # =========================================================================
    print("\n" + "=" * 75)
    print(" SUMMARY")
    print("=" * 75)
    passed = sum(1 for r in results if r[2])
    failed = len(results) - passed
    print(f"\n Strategy: {strategy.value}")
    print(f" Results: {passed}/{len(results)} tests passed")
    # Show kernel selection distribution
    kernel_usage = {}
    for r in results:
        kernel_usage[r[1]] = kernel_usage.get(r[1], 0) + 1
    print("\n Kernel Selection Distribution:")
    for kernel, count in sorted(kernel_usage.items(), key=lambda x: -x[1]):
        print(f" {kernel}: {count} times")
    if results:
        valid_results = [r for r in results if r[2]]
        if valid_results:
            avg_tflops = sum(r[4] for r in valid_results) / len(valid_results)
            print(f"\n Average TFLOPS: {avg_tflops:.2f}")
    if failed == 0:
        print("\n *** ALL TESTS PASSED ***")
    else:
        print(f"\n *** {failed} TESTS FAILED ***")
    print("=" * 75)
    return 0 if failed == 0 else 1
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,220 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 09: Multiple Registries
Demonstrates multiple registries for different optimization targets.
Complexity: ★★★★★
Usage:
python3 09_multi_registry.py
python3 09_multi_registry.py --help
python3 09_multi_registry.py --dtype bf16
"""
import sys
import argparse
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
Registry,
Dispatcher,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
def main():
    """Demonstrate one registry per optimization target.

    Builds three registries (compute / memory / latency), each holding a
    KernelConfig tuned for a different problem-size regime, then picks a
    dispatcher per test size and benchmarks a few square GEMMs.

    Returns:
        int: 0 on success, 1 when the base dispatcher fails to set up.
    """
    parser = argparse.ArgumentParser(
        description="Multiple Registries Example - optimization-specific registries",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 09_multi_registry.py              # Default FP16
  python3 09_multi_registry.py --dtype bf16 # BF16 mode
""",
    )
    parser.add_argument(
        "--dtype",
        default="fp16",
        choices=["fp16", "bf16", "fp32"],
        help="Data type (default: fp16)",
    )
    parser.add_argument(
        "--arch", default="gfx942", help="Target architecture (default: gfx942)"
    )
    args = parser.parse_args()
    reset_for_example()
    print("=" * 60)
    print("Example 09: Multiple Registries")
    print("=" * 60)
    # =========================================================================
    # Step 1: Setup base dispatcher
    # =========================================================================
    # The base setup compiles/loads the shared library once; the registries
    # created below all reuse the same `lib` handle.
    print("\nStep 1: Setup Base Dispatcher")
    base_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        gfx_arch=args.arch,
    )
    setup = setup_gemm_dispatcher(base_config, registry_name="base", verbose=True)
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    lib = setup.lib
    # NOTE(review): bf16 is mapped to np.float16 host-side (NumPy has no native
    # bfloat16 dtype) — presumably the library reinterprets the buffers; confirm.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    # =========================================================================
    # Step 2: Define configs for different optimization targets
    # =========================================================================
    # Large tiles / waves for throughput, medium for balance, small for latency.
    print("\nStep 2: Define Optimization Targets")
    compute_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=256,
        tile_n=256,
        tile_k=64,
        wave_m=4,
        wave_n=4,
        pipeline="compv4",
        gfx_arch=args.arch,
    )
    memory_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=128,
        tile_n=128,
        tile_k=32,
        wave_m=2,
        wave_n=2,
        pipeline="compv4",
        gfx_arch=args.arch,
    )
    latency_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        tile_m=64,
        tile_n=64,
        tile_k=32,
        wave_m=1,
        wave_n=1,
        pipeline="compv3",
        gfx_arch=args.arch,
    )
    print(f" Compute: {compute_config.tile_str} (large matrices)")
    print(f" Memory: {memory_config.tile_str} (medium matrices)")
    print(f" Latency: {latency_config.tile_str} (small matrices)")
    # =========================================================================
    # Step 3: Create registries
    # =========================================================================
    # One registry per target; each holds exactly one kernel here.
    print("\nStep 3: Create Registries")
    compute_registry = Registry(name="compute", lib=lib)
    compute_registry.register_kernel(compute_config)
    memory_registry = Registry(name="memory", lib=lib)
    memory_registry.register_kernel(memory_config)
    latency_registry = Registry(name="latency", lib=lib)
    latency_registry.register_kernel(latency_config)
    # =========================================================================
    # Step 4: Create dispatchers
    # =========================================================================
    print("\nStep 4: Create Dispatchers")
    compute_dispatcher = Dispatcher(registry=compute_registry, lib=lib)
    memory_dispatcher = Dispatcher(registry=memory_registry, lib=lib)
    latency_dispatcher = Dispatcher(registry=latency_registry, lib=lib)
    print(f" {compute_dispatcher}")
    print(f" {memory_dispatcher}")
    print(f" {latency_dispatcher}")
    # =========================================================================
    # Step 5: Smart dispatcher selection
    # =========================================================================
    print("\nStep 5: Smart Dispatcher Selection")

    def select_dispatcher(M: int, N: int, K: int) -> Dispatcher:
        """Pick a dispatcher by output-matrix size (M*N elements only; K ignored)."""
        elements = M * N
        if elements >= 4096 * 4096:
            return compute_dispatcher
        elif elements >= 1024 * 1024:
            return memory_dispatcher
        else:
            return latency_dispatcher

    test_sizes = [
        (256, 256, 256),
        (512, 512, 512),
        (1024, 1024, 1024),
        (2048, 2048, 2048),
        (4096, 4096, 4096),
    ]
    print(f"\n {'Size':<20} {'Registry':>10} {'Time (ms)':>12} {'TFLOPS':>10}")
    print(" " + "-" * 55)
    for M, N, K in test_sizes:
        dispatcher = select_dispatcher(M, N, K)
        # Sizes the selected registry's single kernel cannot handle are skipped
        # silently (no row printed).
        if not dispatcher.is_supported(M, N, K):
            continue
        # Small magnitudes keep low-precision accumulation error down.
        A = np.random.randn(M, K).astype(np_dtype) * 0.1
        B = np.random.randn(K, N).astype(np_dtype) * 0.1
        result = dispatcher.run(A, B, M, N, K)
        if result.success:
            print(
                f" {M}x{N}x{K:<10} {dispatcher.registry.name:>10} "
                f"{result.time_ms:>12.4f} {result.tflops:>10.2f}"
            )
    # Cleanup
    cleanup_gemm()
    # Summary
    print("\n" + "=" * 60)
    print("Multi-Registry Pattern:")
    print("=" * 60)
    print(" 1. Define KernelConfig for each optimization target")
    print(" 2. Create Registry for each target")
    print(" 3. Register configs to appropriate registries")
    print(" 4. Create Dispatcher for each registry")
    print(" 5. Select dispatcher based on problem characteristics")
    print(" 6. Run GEMM with selected dispatcher")
    print("=" * 60)
    return 0
# Script entry point: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,260 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 10: Advanced Benchmarking with Full Control
This example demonstrates all available benchmark parameters:
- warmup: Number of warmup iterations (default: 5)
- repeat: Number of benchmark iterations (default: 20)
- flush_cache: Flush GPU cache between iterations (default: False)
- timer: Timer type - "gpu" (default) or "cpu"
- init: Initialization method - "random", "linear", "constant"
Usage:
python3 10_advanced_benchmark.py
python3 10_advanced_benchmark.py --warmup 10 --repeat 100
python3 10_advanced_benchmark.py --init linear
"""
import argparse
import sys
from pathlib import Path
# Add paths for imports
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent / "python"))
import numpy as np
from ctypes_utils import (
KernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
)
def parse_args():
    """Parse command-line options for the advanced benchmark.

    Returns:
        argparse.Namespace carrying problem size (m/n/k), benchmark knobs
        (warmup/repeat/flush_cache/timer/init), and kernel configuration
        (dtype/pipeline/arch).
    """
    ap = argparse.ArgumentParser(
        description="Advanced GEMM benchmarking with full parameter control"
    )
    # Problem size (GEMM dimensions share the same type and default).
    for flag, label in (("-m", "M"), ("-n", "N"), ("-k", "K")):
        ap.add_argument(flag, type=int, default=2048, help=f"{label} dimension")
    # Benchmark parameters
    ap.add_argument("--warmup", type=int, default=5, help="Number of warmup iterations")
    ap.add_argument(
        "--repeat", type=int, default=20, help="Number of benchmark iterations"
    )
    ap.add_argument(
        "--flush-cache",
        action="store_true",
        help="Flush GPU cache between iterations",
    )
    ap.add_argument(
        "--timer",
        choices=["gpu", "cpu"],
        default="gpu",
        help="Timer type (gpu or cpu)",
    )
    ap.add_argument(
        "--init",
        choices=["random", "linear", "constant"],
        default="random",
        help="Initialization method",
    )
    # Kernel configuration
    ap.add_argument("--dtype", default="fp16", help="Data type")
    ap.add_argument("--pipeline", default="compv4", help="Pipeline type")
    ap.add_argument("--arch", default="gfx942", help="GPU architecture")
    return ap.parse_args()
def initialize_matrix(shape, method, dtype):
    """Create a test matrix initialized with the given method.

    Args:
        shape: Output array shape, e.g. ``(M, K)``.
        method: One of ``"random"`` (scaled normal noise), ``"linear"``
            (sequential values normalized to ``[0, 1)``), or ``"constant"``
            (all ones).
        dtype: NumPy dtype of the returned array.

    Returns:
        np.ndarray with the requested shape and dtype.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "random":
        # Scaled down to keep low-precision accumulation error small.
        return np.random.randn(*shape).astype(dtype) * 0.5
    if method == "linear":
        # Divide before casting: dividing a float16 array by an np.int64
        # scalar can promote the result away from the requested dtype
        # under NumPy 2 promotion rules.
        total = int(np.prod(shape))
        return (np.arange(total).reshape(shape) / total).astype(dtype)
    if method == "constant":
        return np.ones(shape, dtype=dtype)
    # Previously this silently fell back to *unscaled* random data, which
    # both differed from the "random" branch and masked typos; fail loudly.
    raise ValueError(f"Unknown initialization method: {method!r}")
def main():
    """Run one GEMM configuration under fully parameterized benchmarking.

    Initializes inputs per ``--init``, builds a single 128x128x32 kernel
    config, runs ``--warmup`` untimed plus ``--repeat`` timed iterations,
    and reports avg/min/max time, TFLOPS, and effective bandwidth.

    Returns:
        int: 0 on success, 1 on dispatcher-setup failure or zero successful runs.
    """
    args = parse_args()
    reset_for_example()
    print("=" * 70)
    print("Example 10: Advanced GEMM Benchmarking")
    print("=" * 70)
    # Show benchmark configuration
    print("\nBenchmark Configuration:")
    print(f" Problem Size: {args.m} x {args.n} x {args.k}")
    print(f" Warmup: {args.warmup} iterations")
    print(f" Repeat: {args.repeat} iterations")
    print(f" Flush Cache: {args.flush_cache}")
    print(f" Timer: {args.timer}")
    print(f" Init Method: {args.init}")
    print(f" Data Type: {args.dtype}")
    print(f" Pipeline: {args.pipeline}")
    print(f" Architecture: {args.arch}")
    print()
    # Map dtype
    # NOTE(review): bf16 is represented host-side as np.float16 (NumPy has no
    # bfloat16) — confirm the library reinterprets the buffers accordingly.
    np_dtype = np.float16 if args.dtype in ["fp16", "bf16"] else np.float32
    # Initialize matrices
    print("Step 1: Initialize matrices...")
    A = initialize_matrix((args.m, args.k), args.init, np_dtype)
    B = initialize_matrix((args.k, args.n), args.init, np_dtype)
    print(f" A: {A.shape} ({args.init})")
    print(f" B: {B.shape} ({args.init})")
    # Create kernel config (does not include M/N/K - those are problem size)
    print("\nStep 2: Create kernel configuration...")
    kernel_config = KernelConfig(
        dtype_a=args.dtype,
        dtype_b=args.dtype,
        dtype_c=args.dtype,
        dtype_acc="fp32",
        layout_a="row",
        layout_b="col",  # B is column-major for optimal performance
        layout_c="row",
        tile_m=128,
        tile_n=128,
        tile_k=32,
        wave_m=2,
        wave_n=2,
        wave_k=1,
        warp_m=32,
        warp_n=32,
        warp_k=16,
        pipeline=args.pipeline,
        scheduler="intrawave",
        epilogue="cshuffle",
        gfx_arch=args.arch,
    )
    print(f" Config: {args.dtype}, tile=128x128x32, {args.pipeline}")
    # Setup dispatcher
    print("\nStep 3: Setup dispatcher...")
    setup = setup_gemm_dispatcher(
        config=kernel_config,
        registry_name="benchmark_gemm",
        verbose=False,
        auto_rebuild=True,
    )
    if not setup.success:
        print(f" ERROR: {setup.error}")
        return 1
    dispatcher = setup.dispatcher
    print(f" Library: {setup.lib.path if setup.lib else 'N/A'}")
    print(f" Kernel: {setup.lib.get_kernel_name() if setup.lib else 'N/A'}")
    # Run benchmark with multiple iterations
    print("\nStep 4: Run benchmark...")
    print(f" Running {args.warmup} warmup + {args.repeat} benchmark iterations...")
    # Warmup: results discarded; lets caches/clocks settle before timing.
    for _ in range(args.warmup):
        _ = dispatcher.run(A, B, args.m, args.n, args.k)
    # Benchmark: only successful runs contribute a timing sample.
    times = []
    for _ in range(args.repeat):
        result = dispatcher.run(A, B, args.m, args.n, args.k)
        if result.success:
            times.append(result.time_ms)
    if times:
        avg_time = sum(times) / len(times)
        min_time = min(times)
        max_time = max(times)
        # Calculate TFLOPS (2*M*N*K flops per GEMM; min time -> peak TFLOPS)
        flops = 2 * args.m * args.n * args.k
        avg_tflops = (flops / 1e12) / (avg_time / 1000) if avg_time > 0 else 0
        max_tflops = (flops / 1e12) / (min_time / 1000) if min_time > 0 else 0
        # Calculate bandwidth (C has same dtype as A and B)
        C_bytes = args.m * args.n * np.dtype(np_dtype).itemsize
        bandwidth_gb = (
            (A.nbytes + B.nbytes + C_bytes) / 1e9 / (avg_time / 1000)
            if avg_time > 0
            else 0
        )
        print(f"\n *** BENCHMARK RESULTS ({args.repeat} iterations) ***")
        print(f" Average Time: {avg_time:.4f} ms")
        print(f" Min Time: {min_time:.4f} ms")
        print(f" Max Time: {max_time:.4f} ms")
        print(f" Avg TFLOPS: {avg_tflops:.2f}")
        print(f" Peak TFLOPS: {max_tflops:.2f}")
        print(f" Bandwidth: {bandwidth_gb:.2f} GB/s")
    else:
        print(" FAILED: No successful runs")
        return 1
    # Summary
    # NOTE(review): the reference text below says random init is "uniform
    # random [-0.5, 0.5]" but initialize_matrix() uses scaled normal noise
    # (randn * 0.5) — reconcile the help text with the implementation.
    print("\n" + "=" * 70)
    print("BENCHMARK PARAMETERS REFERENCE")
    print("=" * 70)
    print("""
Available parameters for GEMM benchmarking:
--warmup N Number of warmup iterations (discard results)
Higher = more stable results, longer run time
Default: 5
--repeat N Number of benchmark iterations
Higher = more accurate average, longer run time
Default: 20
--flush-cache Flush GPU L2 cache between iterations
Use for memory-bound benchmarks
Default: off
--timer {gpu,cpu} Timer type
gpu = HIP events (more accurate for GPU)
cpu = std::chrono (includes kernel launch overhead)
Default: gpu
--init METHOD Matrix initialization
random = uniform random [-0.5, 0.5]
linear = sequential values
constant = all ones
Default: random
Note: For C++ examples, these parameters are passed to stream_config:
ck_tile::stream_config cfg{
nullptr, // stream_id
true, // time_kernel
1, // log_level
5, // cold_niters (warmup)
20, // nrepeat
true, // is_gpu_timer
false, // flush_cache
1 // rotating_count
};
""")
    # Cleanup
    cleanup_gemm()
    return 0
# Script entry point: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,310 @@
#!/usr/bin/env python3
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
"""
Example 11: JSON-based Kernel Configuration Import
Demonstrates loading kernel configurations from JSON files, similar to tile_engine.
This enables easy customization of kernel sets without modifying code.
Key Features:
- Load tile configs from JSON (compatible with tile_engine format)
- Generate kernel sets from configuration
- Use arch_filter validation on loaded configs
- Export to C++ DECL_KERNEL_SET format
Complexity: ★★★☆☆
Usage:
python3 11_json_import.py
python3 11_json_import.py --config my_kernels.json
python3 11_json_import.py --export-cpp
"""
import sys
import argparse
import json
from pathlib import Path
# Add codegen to path for kernel_config_loader
script_dir = Path(__file__).parent.resolve()
sys.path.insert(0, str(script_dir.parent.parent.parent / "codegen"))
sys.path.insert(0, str(script_dir.parent.parent.parent / "python"))
from kernel_config_loader import ( # noqa: E402
load_kernel_configs,
KernelConfig,
generate_cpp_kernel_set_declaration,
)
from ctypes_utils import ( # noqa: E402
KernelConfig as DispatcherKernelConfig,
setup_gemm_dispatcher,
cleanup_gemm,
reset_for_example,
validate_kernel_config,
)
# Sample JSON configuration (embedded for demonstration)
# Mirrors the tile_engine JSON schema: each "tile_config"/"trait_config"
# entry carries a {"values": [...]} list, and the cross-product of all the
# value lists defines the kernel set (here 2 tile_m x 2 tile_n = 4 configs).
SAMPLE_JSON_CONFIG = {
    "_comment": "Sample kernel configuration for GEMM",
    "kernel_set_name": "inference_kernels",
    "datatype": {"a": "fp16", "b": "fp16", "c": "fp16", "acc": "fp32"},
    "layout": "rcr",
    "tile_config": {
        "tile_m": {"values": [128, 256]},
        "tile_n": {"values": [128, 256]},
        "tile_k": {"values": [32]},
        "warp_m": {"values": [2]},
        "warp_n": {"values": [2]},
        "warp_k": {"values": [1]},
        "warp_tile_m": {"values": [32]},
        "warp_tile_n": {"values": [32]},
        "warp_tile_k": {"values": [16]},
    },
    "trait_config": {
        "pipeline": {"values": ["compv4"]},
        "scheduler": {"values": ["intrawave"]},
        "epilogue": {"values": ["cshuffle"]},
        "pad_m": {"values": [False]},
        "pad_n": {"values": [False]},
        "pad_k": {"values": [False]},
    },
    "gpu_targets": ["gfx942"],
}
def print_section(title: str):
    """Write a 70-column section divider around *title* to stdout."""
    bar = "=" * 70
    print(f"\n{bar}\n {title}\n{bar}\n")
def convert_to_dispatcher_config(
    config: KernelConfig, arch: str = "gfx942"
) -> DispatcherKernelConfig:
    """Map a loader-side KernelConfig onto the dispatcher's KernelConfig.

    The loader groups fields under ``config.tile`` / ``config.trait``; the
    dispatcher takes them as flat keyword arguments. Note the renaming:
    loader ``warp_*`` becomes dispatcher ``wave_*``, and loader
    ``warp_tile_*`` becomes dispatcher ``warp_*``.
    """
    tile, trait = config.tile, config.trait
    fields = {
        "dtype_a": config.dtype_a,
        "dtype_b": config.dtype_b,
        "dtype_c": config.dtype_c,
        "dtype_acc": config.dtype_acc,
        "tile_m": tile.tile_m,
        "tile_n": tile.tile_n,
        "tile_k": tile.tile_k,
        "wave_m": tile.warp_m,
        "wave_n": tile.warp_n,
        "wave_k": tile.warp_k,
        "warp_m": tile.warp_tile_m,
        "warp_n": tile.warp_tile_n,
        "warp_k": tile.warp_tile_k,
        "pipeline": trait.pipeline,
        "scheduler": trait.scheduler,
        "epilogue": trait.epilogue,
        "pad_m": trait.pad_m,
        "pad_n": trait.pad_n,
        "pad_k": trait.pad_k,
        "gfx_arch": arch,
        "variant": config.variant,
    }
    return DispatcherKernelConfig(**fields)
def main():
    """Load a kernel-set definition from JSON and exercise it end to end.

    Loads either ``--config`` or the embedded sample, prints the parsed
    set, expands it into concrete configs, optionally validates them
    against the target arch (``--validate``) or exports C++ declarations
    (``--export-cpp``), and finally tries a dispatcher setup with the
    first config.

    Returns:
        int: 0 on success, 1 when the given config file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="JSON Kernel Configuration Import Example",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 11_json_import.py                  # Use embedded sample config
  python3 11_json_import.py --config my.json # Load from file
  python3 11_json_import.py --export-cpp     # Generate C++ declarations
  python3 11_json_import.py --validate       # Validate configs against arch
""",
    )
    parser.add_argument(
        "--config",
        type=str,
        help="Path to JSON configuration file (uses embedded sample if not provided)",
    )
    parser.add_argument(
        "--export-cpp",
        action="store_true",
        help="Export kernel set as C++ DECL_KERNEL_SET",
    )
    parser.add_argument(
        "--validate",
        action="store_true",
        help="Validate all configurations against arch filter",
    )
    parser.add_argument(
        "--arch",
        default="gfx942",
        help="Target GPU architecture (default: gfx942)",
    )
    args = parser.parse_args()
    reset_for_example()
    print_section("Example 11: JSON Kernel Configuration Import")
    # =========================================================================
    # Step 1: Load configuration from JSON
    # =========================================================================
    print("Step 1: Load Kernel Configuration from JSON")
    print("-" * 50)
    if args.config:
        config_path = Path(args.config)
        if not config_path.exists():
            print(f" ERROR: Config file not found: {config_path}")
            return 1
        print(f" Loading from: {config_path}")
        config_set = load_kernel_configs(config_path)
    else:
        # Use embedded sample config
        print(" Using embedded sample configuration")
        # Write to temp file and load
        # NOTE(review): hard-coded /tmp path; tempfile.NamedTemporaryFile would
        # avoid collisions between concurrent runs — confirm before changing.
        temp_path = Path("/tmp/sample_gemm_config.json")
        with open(temp_path, "w") as f:
            json.dump(SAMPLE_JSON_CONFIG, f, indent=2)
        config_set = load_kernel_configs(temp_path)
    print(f"\n Kernel Set Name: {config_set.name}")
    print(
        f" Data Types: A={config_set.dtype_a}, B={config_set.dtype_b}, C={config_set.dtype_c}"
    )
    print(f" Layout: {config_set.layout}")
    print(f" GPU Targets: {config_set.gpu_targets}")
    print(f" Total Configurations: {config_set.config_count()}")
    # =========================================================================
    # Step 2: Display configuration details
    # =========================================================================
    print("\nStep 2: Configuration Details")
    print("-" * 50)
    print("\n Tile Configurations:")
    print(f" tile_m: {config_set.tile_m_values}")
    print(f" tile_n: {config_set.tile_n_values}")
    print(f" tile_k: {config_set.tile_k_values}")
    print(
        f" warp (wave): {config_set.warp_m_values}x{config_set.warp_n_values}x{config_set.warp_k_values}"
    )
    print(
        f" warp_tile: {config_set.warp_tile_m_values}x{config_set.warp_tile_n_values}x{config_set.warp_tile_k_values}"
    )
    print("\n Trait Configurations:")
    print(f" pipeline: {config_set.pipeline_values}")
    print(f" scheduler: {config_set.scheduler_values}")
    print(f" epilogue: {config_set.epilogue_values}")
    print(
        f" padding: m={config_set.pad_m_values}, n={config_set.pad_n_values}, k={config_set.pad_k_values}"
    )
    # =========================================================================
    # Step 3: Generate and display kernel names
    # =========================================================================
    # Expand the value cross-product into concrete configs; show up to five.
    print("\nStep 3: Generated Kernel Names")
    print("-" * 50)
    configs = list(config_set.generate_configs())
    for i, config in enumerate(configs[:5]):
        print(f" {i + 1}. {config.kernel_name()}")
    if len(configs) > 5:
        print(f" ... and {len(configs) - 5} more configurations")
    # =========================================================================
    # Step 4: Validate against arch filter (optional)
    # =========================================================================
    if args.validate:
        print("\nStep 4: Architecture Validation")
        print("-" * 50)
        valid_count = 0
        invalid_count = 0
        for config in configs:
            disp_config = convert_to_dispatcher_config(config, args.arch)
            result = validate_kernel_config(disp_config)
            if result.is_valid:
                valid_count += 1
            else:
                invalid_count += 1
                if invalid_count <= 3:  # Show first 3 invalid
                    print(f"\n ✗ Invalid: {config.kernel_name()}")
                    for error in result.errors:
                        print(f" Error: {error}")
        print("\n Validation Summary:")
        print(f" ✓ Valid: {valid_count}")
        print(f" ✗ Invalid: {invalid_count}")
        print(f" Total: {len(configs)}")
    # =========================================================================
    # Step 5: Export to C++ (optional)
    # =========================================================================
    if args.export_cpp:
        print("\nStep 5: C++ Export")
        print("-" * 50)
        print("\n // Generated DECL_KERNEL_SET from JSON config:")
        print(" // " + "=" * 56)
        cpp_code = generate_cpp_kernel_set_declaration(config_set)
        for line in cpp_code.split("\n"):
            print(f" {line}")
    # =========================================================================
    # Step 6: Use first config with dispatcher (demo)
    # =========================================================================
    # Setup failure here is non-fatal: kernels may simply not be generated yet.
    print("\nStep 6: Dispatcher Integration Demo")
    print("-" * 50)
    if configs:
        first_config = configs[0]
        disp_config = convert_to_dispatcher_config(first_config, args.arch)
        print(
            f"\n Using first config: {first_config.tile.tile_m}x{first_config.tile.tile_n}x{first_config.tile.tile_k}"
        )
        setup = setup_gemm_dispatcher(
            disp_config, registry_name="json_import", verbose=False
        )
        if setup.success:
            print(" ✓ Dispatcher setup successful")
            print(
                f" Kernel header: {setup.kernel_header.name if setup.kernel_header else 'N/A'}"
            )
        else:
            print(f" ⚠ Dispatcher setup: {setup.error}")
            print(" (This is expected if kernels aren't generated)")
    # =========================================================================
    # Summary
    # =========================================================================
    print_section("Summary")
    print(" JSON configuration allows easy kernel set customization:")
    print(" - Define tile sizes and ranges")
    print(" - Specify trait combinations (pipeline, scheduler, etc.)")
    print(" - Target multiple GPU architectures")
    print(" - Export to C++ DECL_KERNEL_SET for static compilation")
    print()
    print(" JSON Format (tile_engine compatible):")
    print('  {"tile_config": {"tile_m": {"values": [128, 256]}, ...},')
    print('  "trait_config": {"pipeline": {"values": ["compv4"]}, ...}}')
    print()
    print(" Usage:")
    print("  config_set = load_kernel_configs('my_kernels.json')")
    print("  for config in config_set.generate_configs():")
    print("  # Use config for codegen or dispatcher setup")
    cleanup_gemm()
    return 0
# Script entry point: propagate main()'s status code to the shell.
if __name__ == "__main__":
    sys.exit(main())

View File

@@ -0,0 +1,299 @@
# GEMM Python Examples
CK Tile Dispatcher Python examples for GEMM (General Matrix Multiplication) operations.
> **Main Documentation**: [Dispatcher README](../../../README.md) | [Examples Overview](../../README.md)
## Quick Start
### Build Library
```bash
cd /path/to/composable_kernel/dispatcher
mkdir -p build && cd build
cmake .. \
-DCMAKE_PREFIX_PATH=/opt/rocm \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DBUILD_DISPATCHER_EXAMPLES=ON
# Build Python library (kernels generated automatically)
make dispatcher_gemm_lib -j$(nproc)
```
### Run Examples
```bash
cd /path/to/composable_kernel/dispatcher
python3 examples/gemm/python/01_basic_gemm.py
python3 examples/gemm/python/04_validation.py
python3 examples/gemm/python/07_stress_test.py
python3 examples/gemm/python/08_heuristics.py
```
## Examples
| Example | Description |
|---------|-------------|
| [01_basic_gemm.py](01_basic_gemm.py) | Basic GEMM with multi-kernel support |
| [02_batch_gemm.py](02_batch_gemm.py) | Batched GEMM operations |
| [03_benchmark.py](03_benchmark.py) | Performance benchmarking |
| [04_validation.py](04_validation.py) | CPU reference validation |
| [05_numpy_integration.py](05_numpy_integration.py) | NumPy array integration |
| [06_json_export.py](06_json_export.py) | Registry JSON export |
| [07_stress_test.py](07_stress_test.py) | Multi-kernel stress testing |
| [08_heuristics.py](08_heuristics.py) | Heuristic-based kernel selection |
| [09_multi_registry.py](09_multi_registry.py) | Multiple registries |
| [10_advanced_benchmark.py](10_advanced_benchmark.py) | Advanced benchmark with full control |
| [11_json_import.py](11_json_import.py) | Import kernels from JSON |
## Example Details
### 01_basic_gemm.py - Basic GEMM
Demonstrates the Python API with multi-kernel support:
```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table
# Define multiple kernel configurations
kernels = [
KernelConfig(
tile_m=128, tile_n=128, tile_k=32,
wave_m=2, wave_n=2, wave_k=1,
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
pipeline="compv3", scheduler="intrawave"
),
KernelConfig(
tile_m=256, tile_n=256, tile_k=32,
wave_m=2, wave_n=2, wave_k=1,
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
pipeline="compv4", scheduler="intrawave"
),
]
# Display configurations
print_kernel_config_table(kernels)
# Set up dispatcher with all kernels
setup = setup_gemm_dispatcher(kernels)
lib, dispatcher = setup.lib, setup.dispatcher
# Run GEMM
elapsed_ms = run_gemm(lib, M, N, K, ...)
```
### 02_batch_gemm.py - Batch GEMM
Batched matrix multiplication:
- Multiple independent GEMM operations
- Batch dimension handling
### 03_benchmark.py - Benchmarking
Performance measurement:
- GPU timing
- TFLOPS calculation
- Multiple iterations
### 04_validation.py - Validation
Correctness verification:
- NumPy reference implementation
- Tolerance-based validation
- Error reporting
### 05_numpy_integration.py - NumPy Integration
Seamless NumPy integration:
- NumPy arrays to GPU buffers
- Results back to NumPy
- Automatic type conversion
### 06_json_export.py - JSON Export
Registry serialization for tool integration:
- Export kernel configurations
- Machine-readable format
### 07_stress_test.py - Stress Testing
Comprehensive multi-kernel stress testing:
```python
from ctypes_utils import KernelConfig, setup_gemm_dispatcher, print_kernel_config_table
# Define 48 unique kernel configurations
kernels = [
KernelConfig(tile_m=128, tile_n=128, tile_k=32, pipeline="compv3", ...),
KernelConfig(tile_m=256, tile_n=256, tile_k=32, pipeline="compv4", ...),
KernelConfig(tile_m=128, tile_n=256, tile_k=64, pipeline="compv3", ...),
# ... many more configurations
]
# Test each kernel
for i, kernel in enumerate(kernels):
lib, dispatcher, registry = setup_gemm_dispatcher([kernel])
result = run_and_validate(lib, M, N, K, seed=42 + i) # Different seed per kernel
print(f"Kernel {i}: {result.max_err:.6e} {'PASS' if result.passed else 'FAIL'}")
```
**Features:**
- 48 unique kernel configurations
- Various tile sizes, pipelines, and schedulers
- Per-kernel validation with unique random seeds
- Performance reporting
### 08_heuristics.py - Heuristic Selection
Custom kernel selection based on problem characteristics:
```python
# Define kernel pools for different strategies
SMALL_KERNELS = [KernelConfig(tile_m=64, tile_n=64, ...), ...]
LARGE_KERNELS = [KernelConfig(tile_m=256, tile_n=256, ...), ...]
COMPUTE_KERNELS = [KernelConfig(pipeline="compv4", ...), ...]
MEMORY_KERNELS = [KernelConfig(pipeline="compv3", ...), ...]
# Size-based heuristic
def size_based_heuristic(M, N, K):
if M * N < 512 * 512:
return SMALL_KERNELS
else:
return LARGE_KERNELS
# Strategy-based selection
def compute_strategy():
return COMPUTE_KERNELS # Optimized for compute-bound problems
def memory_strategy():
return MEMORY_KERNELS # Optimized for memory-bound problems
# Test different strategies
for strategy in [size_based_heuristic, compute_strategy, memory_strategy]:
kernels = strategy(M, N, K)
lib, dispatcher, registry = setup_gemm_dispatcher(kernels)
elapsed_ms = run_gemm(lib, M, N, K, ...)
```
**Features:**
- 24 kernel configurations across 6 categories
- Size-based heuristic (small vs large)
- Optimization strategies (compute, memory, latency)
- Performance comparison across strategies
### 09_multi_registry.py - Multiple Registries
Separate registries for different workloads:
- Compute-optimized registry
- Latency-optimized registry
- Dynamic registry selection
### 10_advanced_benchmark.py - Advanced Benchmark
Full control over benchmark parameters:
- Warmup iterations
- Benchmark iterations
- Statistical analysis
### 11_json_import.py - JSON Import
Import kernel configurations from JSON:
- External configuration files
- Dynamic kernel loading
## Utility Module: ctypes_utils.py
```python
from ctypes_utils import (
KernelConfig, # Single kernel configuration
setup_gemm_dispatcher, # Set up dispatcher with kernels
print_kernel_config_table, # Display kernel configurations
Dispatcher, # High-level dispatcher
Registry, # Kernel registry
Validator, # Validation utilities
)
```
### KernelConfig
```python
config = KernelConfig(
# Tile sizes
tile_m=256, tile_n=256, tile_k=32,
# Wave configuration
wave_m=2, wave_n=2, wave_k=1,
# Warp tile sizes
warp_tile_m=32, warp_tile_n=32, warp_tile_k=16,
# Pipeline and scheduler
pipeline="compv4", # "compv3" or "compv4"
scheduler="intrawave", # "intrawave" or "interwave"
# Optional
epilogue="default",
padding=True,
double_buffer=True,
)
```
### setup_gemm_dispatcher
```python
# Single kernel — returns a setup result object; check setup.success before use
setup = setup_gemm_dispatcher(config)
lib, dispatcher = setup.lib, setup.dispatcher
# Multiple kernels
setup = setup_gemm_dispatcher([config1, config2, ...])
# With auto-rebuild
setup = setup_gemm_dispatcher(config, auto_rebuild=True)
```
### print_kernel_config_table
```python
kernels = [config1, config2, config3]
print_kernel_config_table(kernels)
# Output:
# +---+------------+-------+----------+--------+-----------+
# | # | Tile       | Wave  | Warp     | Pipe   | Scheduler |
# +---+------------+-------+----------+--------+-----------+
# | 1 | 128x128x32 | 2x2x1 | 32x32x16 | compv3 | intrawave |
# | 2 | 256x256x32 | 2x2x1 | 32x32x16 | compv4 | intrawave |
# | 3 | 128x256x64 | 2x2x1 | 32x32x16 | compv3 | interwave |
# +---+------------+-------+----------+--------+-----------+
```
### GPU Memory Management
```python
import ctypes
import numpy as np
# Load HIP library
hip = ctypes.CDLL("libamdhip64.so")
# Allocate GPU memory
gpu_ptr = ctypes.c_void_p()
hip.hipMalloc(ctypes.byref(gpu_ptr), size_in_bytes)
# Copy to GPU (1 = hipMemcpyHostToDevice)
hip.hipMemcpy(gpu_ptr, host_array.ctypes.data, size, 1)
# Copy back (2 = hipMemcpyDeviceToHost)
hip.hipMemcpy(host_array.ctypes.data, gpu_ptr, size, 2)
# Free
hip.hipFree(gpu_ptr)
```
## Performance Testing
Test compilation performance with different kernel counts:
```bash
# Test with 10 kernels (~15s compile time)
python3 01_basic_gemm.py --num-kernels 10
# Test with 20 kernels (~25s compile time)
python3 01_basic_gemm.py --num-kernels 20
# Test with 48 kernels (~50s compile time)
python3 01_basic_gemm.py --num-kernels 48
```
Compilation time scales roughly linearly with kernel count.
## Related Documentation
- [C++ GEMM Examples](../cpp/README.md)
- [Python Conv Examples](../../conv/python/README.md)
- [Main Dispatcher README](../../../README.md)

View File

@@ -0,0 +1,80 @@
{
"registry": "export_demo",
"kernel_count": 3,
"kernels": [
{
"tile": "128x128x32",
"dtypes": {
"A": "fp16",
"B": "fp16",
"C": "fp16"
},
"layout": "rcr",
"pipeline": "compv4",
"target": "gfx942"
},
{
"tile": "256x256x64",
"dtypes": {
"A": "fp16",
"B": "fp16",
"C": "fp16"
},
"layout": "rcr",
"pipeline": "compv4",
"target": "gfx942"
},
{
"tile": "64x64x32",
"dtypes": {
"A": "fp16",
"B": "fp16",
"C": "fp16"
},
"layout": "rcr",
"pipeline": "compv4",
"target": "gfx942"
}
],
"cpp_registry": {
"metadata": {
"timestamp": "Dec 4 2025 06:23:15",
"total_kernels": 1,
"export_version": "1.0",
"dispatcher_version": "1.0.0"
},
"statistics": {
"by_datatype": {},
"by_pipeline": {},
"by_scheduler": {}
},
"kernels": [
{
"identifier": "128x128x32_2x2x1_32x32x16_nopers",
"name": "gemm_fp16_rcrr_compv4_cshuffle_intrawave_False_False_False_False_128x128x32_2x2x1_32x32x16",
"algorithm": {
"tile_shape": {
"m": 128,
"n": 128,
"k": 32
},
"wave_shape": {
"m": 2,
"n": 2,
"k": 1
},
"warp_tile_shape": {
"m": 32,
"n": 32,
"k": 16
},
"block_size": 256,
"persistent": false,
"double_buffer": true,
"preshuffle": false,
"transpose_c": false
}
}
]
}
}

Some files were not shown because too many files have changed in this diff Show More